From 0e8fc16e65e86ea667241f9afd3def3d904d04f1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 16 Jun 2025 09:27:05 -0400 Subject: [PATCH] [FFI] Introduce FFI reflection support in python This PR brings up new reflection support in python. The new reflection now directly attaches property and methods to the class object themselves, making more efficient accessing than old mechanism. It will also support broader set of value types that are compatible with the FFI system. For now the old mechanism and new mechanism will co-exist, and we will phase out old mechanism as we migrate most needed features into new one. --- ffi/include/tvm/ffi/memory.h | 8 +- ffi/include/tvm/ffi/reflection/reflection.h | 159 ++++++++++++-------- ffi/include/tvm/ffi/string.h | 12 ++ ffi/src/ffi/object.cc | 77 ++++++++++ ffi/src/ffi/testing.cc | 42 ++++++ ffi/tests/cpp/test_reflection.cc | 10 +- python/tvm/ffi/__init__.py | 1 + python/tvm/ffi/cython/base.pxi | 47 ++++++ python/tvm/ffi/cython/error.pxi | 1 + python/tvm/ffi/cython/function.pxi | 95 ++++++++++++ python/tvm/ffi/cython/ndarray.pxi | 2 +- python/tvm/ffi/registry.py | 1 + python/tvm/ffi/testing.py | 63 ++++++++ tests/python/ffi/test_object.py | 70 +++++++++ 14 files changed, 519 insertions(+), 69 deletions(-) create mode 100644 python/tvm/ffi/testing.py create mode 100644 tests/python/ffi/test_object.py diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index eb317d2bbd72..02537df79cb4 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -70,7 +70,7 @@ class ObjAllocatorBase { * \param args The arguments. */ template - inline ObjectPtr make_object(Args&&... args) { + ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); @@ -89,7 +89,7 @@ class ObjAllocatorBase { * \param args The arguments. */ template - inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { + ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { using Handler = typename Derived::template ArrayHandler; static_assert(std::is_base_of::value, "make_inplace_array can only be used to create Object"); @@ -109,7 +109,9 @@ class SimpleObjAllocator : public ObjAllocatorBase { template class Handler { public: - using StorageType = typename std::aligned_storage::type; + struct alignas(T) StorageType { + char data[sizeof(T)]; + }; template static T* New(SimpleObjAllocator*, Args&&... args) { diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index bd2f5cb9c76e..6187a74825d6 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -46,7 +46,7 @@ class DefaultValue : public FieldInfoTrait { public: explicit DefaultValue(Any value) : value_(value) {} - void Apply(TVMFFIFieldInfo* info) const { + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->default_value = AnyView(value_).CopyToTVMFFIAny(); info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; } @@ -65,16 +65,89 @@ class DefaultValue : public FieldInfoTrait { * \returns The byteoffset */ template -inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { +TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); } +class ReflectionDefBase { + protected: + template + static int FieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int FieldSetter(void* field, const TVMFFIAny* value) { + TVM_FFI_SAFE_CALL_BEGIN(); + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + + template + static TVM_FFI_INLINE void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + static TVM_FFI_INLINE void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + static TVM_FFI_INLINE void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + template + static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static TVM_FFI_INLINE Function GetMethod(std::string name, Func&& func) { + return ffi::Function::FromTyped(std::forward(func), name); + } +}; + template -class ObjectDef { +class ObjectDef : public ReflectionDefBase { public: - ObjectDef() : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {} + template + explicit ObjectDef(ExtraArgs&&... extra_args) + : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { + RegisterExtraInfo(std::forward(extra_args)...); + } /*! * \brief Define a readonly field. @@ -90,7 +163,7 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { RegisterField(name, field_ptr, false, std::forward(extra)...); return *this; } @@ -109,7 +182,8 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); RegisterField(name, field_ptr, true, std::forward(extra)...); return *this; } @@ -127,7 +201,7 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { RegisterMethod(name, false, std::forward(func), std::forward(extra)...); return *this; } @@ -145,12 +219,26 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { RegisterMethod(name, true, std::forward(func), std::forward(extra)...); return *this; } private: + template + void RegisterExtraInfo(ExtraArgs&&... extra_args) { + TVMFFITypeExtraInfo info; + info.total_size = sizeof(Class); + info.creator = nullptr; + info.doc = TVMFFIByteArray{nullptr, 0}; + if constexpr (std::is_default_constructible_v) { + info.creator = ObjectCreatorDefault; + } + // apply extra info traits + ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info)); + } + template void RegisterField(const char* name, T Class::*field_ptr, bool writable, ExtraArgs&&... extra_args) { @@ -178,30 +266,6 @@ class ObjectDef { TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); } - template - static int FieldGetter(void* field, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void* field, const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - TVM_FFI_SAFE_CALL_END(); - } - - template - static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { @@ -214,41 +278,14 @@ class ObjectDef { info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; } // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + Function method = + GetMethod(std::string(type_key_) + "." + name, std::forward(func)); info.method = AnyView(method).CopyToTVMFFIAny(); // apply method info traits ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); } - template - static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - template - static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - template - static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), name); - } - int32_t type_index_; const char* type_key_; }; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 19df2e8e3dcf..dee2d89c0854 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -306,6 +306,18 @@ class String : public ObjectRef { return Bytes::memncmp(data(), other, size(), std::strlen(other)); } + /*! + * \brief Compares this to other + * + * \param other The TVMFFIByteArray to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const TVMFFIByteArray& other) const { + return Bytes::memncmp(data(), other.data, size(), other.size); + } + /*! * \brief Returns a pointer to the char array in the string. * diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 793d3e27283a..fa77e2b26401 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -315,6 +315,83 @@ class TypeTable { Map type_key2index_; std::vector any_pool_; }; + +void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { + String type_key = args[0].cast(); + TVM_FFI_ICHECK(args.size() % 2 == 1); + + int32_t type_index; + TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + if (type_info == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Cannot find type `" << type_key << "`"; + } + + if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << type_key << "` does not support reflection creation"; + } + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + + std::vector keys; + std::vector keys_found; + + for (int i = 1; i < args.size(); i += 2) { + keys.push_back(args[i].cast()); + } + keys_found.resize(keys.size(), false); + + auto search_field = [&](const TVMFFIByteArray& field_name) { + for (size_t i = 0; i < keys.size(); ++i) { + if (keys_found[i]) continue; + if (keys[i].compare(field_name) == 0) { + return i; + } + } + return keys.size(); + }; + + auto update_fields = [&](const TVMFFITypeInfo* tinfo) { + for (int i = 0; i < tinfo->num_fields; ++i) { + const TVMFFIFieldInfo* field_info = tinfo->fields + i; + size_t arg_index = search_field(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (arg_index < keys.size()) { + AnyView field_value = args[arg_index * 2 + 2]; + field_info->setter(field_addr, reinterpret_cast(&field_value)); + keys_found[arg_index] = true; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" << type_key << "`"; + } + } + }; + + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject); + for (int i = 1; i < type_info->type_depth; ++i) { + update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i])); + } + update_fields(type_info); + + for (size_t i = 0; i < keys.size(); ++i) { + if (!keys_found[i]) { + TVM_FFI_THROW(TypeError) << "Type `" << type_key << "` does not have field `" << keys[i] + << "`"; + } + } + *ret = ObjectRef(ptr); +} + +TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs); + } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc index 050ac28c476e..6bc7968eab06 100644 --- a/ffi/src/ffi/testing.cc +++ b/ffi/src/ffi/testing.cc @@ -17,7 +17,10 @@ * under the License. */ // This file is used for testing the FFI API. +#include +#include #include +#include #include #include @@ -26,6 +29,45 @@ namespace tvm { namespace ffi { +class TestObjectBase : public Object { + public: + int64_t v_i64; + double v_f64; + String v_str; + + int64_t AddI64(int64_t other) const { return v_i64 + other; } + + // declare as one slot, with float as overflow + static constexpr bool _type_mutable = true; + static constexpr uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "testing.TestObjectBase"; + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjectBase, Object); +}; + +class TestObjectDerived : public TestObjectBase { + public: + Map v_map; + Array v_array; + + // declare as one slot, with float as overflow + static constexpr const char* _type_key = "testing.TestObjectDerived"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase); +}; + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + + refl::ObjectDef() + .def_rw("v_i64", &TestObjectBase::v_i64, refl::DefaultValue(10), "i64 field") + .def_ro("v_f64", &TestObjectBase::v_f64, refl::DefaultValue(10.0)) + .def_rw("v_str", &TestObjectBase::v_str, refl::DefaultValue("hello")) + .def("add_i64", &TestObjectBase::AddI64, "add_i64 method"); + + refl::ObjectDef() + .def_ro("v_map", &TestObjectDerived::v_map) + .def_ro("v_array", &TestObjectDerived::v_array); +}); + void TestRaiseError(String kind, String msg) { throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); } diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index 64b3a6f590eb..450cb9dbcbf7 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -32,13 +32,15 @@ using namespace tvm::ffi::testing; struct A : public Object { int64_t x; int64_t y; + + static constexpr bool _type_mutable = true; }; TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_rw("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) + .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; }) .def("add", &TFloatObj::Add, "add method"); @@ -47,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_static("static_add", &TInt::StaticAdd, "static add method"); refl::ObjectDef() - .def_ro("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) + .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) .def("sub", [](TPrimExprObj* self, double other) -> double { // this is ok because TPrimExprObj is declared asmutable @@ -89,7 +91,7 @@ TEST(Reflection, FieldInfo) { const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value"); EXPECT_EQ(info_float->default_value.v_float64, 10.0); EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); + EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field"); const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); @@ -97,7 +99,7 @@ TEST(Reflection, FieldInfo) { EXPECT_EQ(default_value.cast(), "float"); EXPECT_EQ(default_value.as().value().use_count(), 2); EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); + EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); } diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py index 0a8b223405b9..b507064e34d9 100644 --- a/python/tvm/ffi/__init__.py +++ b/python/tvm/ffi/__init__.py @@ -30,6 +30,7 @@ from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu from .ndarray import from_dlpack, NDArray, Shape from .container import Array, Map +from . import testing __all__ = [ diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index e18d52fc8d84..50831be462ad 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -134,6 +134,52 @@ cdef extern from "tvm/ffi/c_api.h": void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) noexcept + cdef enum TVMFFIFieldFlagBitMask: + kTVMFFIFieldFlagBitMaskWritable = 1 << 0 + kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 + kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 + + ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept; + ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept; + ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept; + + ctypedef struct TVMFFIFieldInfo: + TVMFFIByteArray name + TVMFFIByteArray doc + TVMFFIByteArray type_schema + int64_t flags + int64_t size + int64_t alignment + int64_t offset + TVMFFIFieldGetter getter + TVMFFIFieldSetter setter + TVMFFIAny default_value + int32_t field_static_type_index + + ctypedef struct TVMFFIMethodInfo: + TVMFFIByteArray name + TVMFFIByteArray doc + TVMFFIByteArray type_schema + int64_t flags + TVMFFIAny method + + ctypedef struct TVMFFITypeExtraInfo: + TVMFFIByteArray doc + TVMFFIObjectCreator creator + int64_t total_size + + ctypedef struct TVMFFITypeInfo: + int32_t type_index + int32_t type_depth + TVMFFIByteArray type_key + const int32_t* type_acenstors + uint64_t type_key_hash + int32_t num_fields + int32_t num_methods + const TVMFFIFieldInfo* fields + const TVMFFIMethodInfo* methods + const TVMFFITypeExtraInfo* extra_info + int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, @@ -161,6 +207,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFINDArrayToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src, DLManagedTensorVersioned** out) nogil + const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi index 3a19573b8f94..8da630873ede 100644 --- a/python/tvm/ffi/cython/error.pxi +++ b/python/tvm/ffi/cython/error.pxi @@ -113,6 +113,7 @@ cdef class Error(Object): def traceback(self): return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).traceback)) + _register_object_by_index(kTVMFFIError, Error) diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 294a1246b27b..640fff7af557 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -230,6 +230,101 @@ class Function(Object): _register_object_by_index(kTVMFFIFunction, Function) +cdef class FieldGetter: + cdef TVMFFIFieldGetter getter + cdef int64_t offset + + def __call__(self, Object obj): + cdef TVMFFIAny result + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + result.type_index = kTVMFFINone + result.v_int64 = 0 + c_api_ret_code = self.getter(field_ptr, &result) + CHECK_CALL(c_api_ret_code) + return make_ret(result) + + +cdef class FieldSetter: + cdef TVMFFIFieldSetter setter + cdef int64_t offset + + def __call__(self, Object obj, value): + cdef TVMFFIAny[1] packed_args + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + cdef int nargs = 1 + temp_args = [] + make_args((value,), &packed_args[0], temp_args) + c_api_ret_code = self.setter(field_ptr, &packed_args[0]) + # NOTE: logic is same as check_call + # directly inline here to simplify traceback + if c_api_ret_code == 0: + return + elif c_api_ret_code == -2: + raise_existing_error() + raise move_from_last_error().py_error() + + +cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): + cdef TVMFFIAny result + CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) + return make_ret(result) + + +def _add_class_attrs_by_reflection(int type_index, object cls): + """Decorate the class attrs by reflection""" + cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) + cdef const TVMFFIFieldInfo* field + cdef const TVMFFIMethodInfo* method + cdef int num_fields = info.num_fields + cdef int num_methods = info.num_methods + + for i in range(num_fields): + # attach fields to the class + field = &(info.fields[i]) + getter = FieldGetter.__new__(FieldGetter) + (getter).getter = field.getter + (getter).offset = field.offset + setter = FieldSetter.__new__(FieldSetter) + (setter).setter = field.setter + (setter).offset = field.offset + if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0: + setter = None + doc = ( + py_str(PyBytes_FromStringAndSize(field.doc.data, field.doc.size)) + if field.doc.size != 0 + else None + ) + name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size)) + setattr(cls, name, property(getter, setter, doc=doc)) + + for i in range(num_methods): + # attach methods to the class + method = &(info.methods[i]) + name = py_str(PyBytes_FromStringAndSize(method.name.data, method.name.size)) + doc = ( + py_str(PyBytes_FromStringAndSize(method.doc.data, method.doc.size)) + if method.doc.size != 0 + else None + ) + method_func = _get_method_from_method_info(method) + + if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: + method_pyfunc = staticmethod(method_func) + else: + def method_pyfunc(self, *args): + return method_func(self, *args) + + if doc is not None: + method_pyfunc.__doc__ = doc + method_pyfunc.__name__ = name + + setattr(cls, name, method_pyfunc) + + return cls + + def _register_global_func(name, pyfunc, override): cdef TVMFFIObjectHandle chandle cdef int c_api_ret_code diff --git a/python/tvm/ffi/cython/ndarray.pxi b/python/tvm/ffi/cython/ndarray.pxi index b8534b41b38b..9dfe1222dc7e 100644 --- a/python/tvm/ffi/cython/ndarray.pxi +++ b/python/tvm/ffi/cython/ndarray.pxi @@ -23,7 +23,6 @@ _CLASS_NDARRAY = None def _set_class_ndarray(cls): global _CLASS_NDARRAY _CLASS_NDARRAY = cls - _register_object_by_index(kTVMFFINDArray, cls) cdef const char* _c_str_dltensor = "dltensor" @@ -268,6 +267,7 @@ cdef class NDArray(Object): _set_class_ndarray(NDArray) +_register_object_by_index(kTVMFFINDArray, NDArray) cdef inline object make_ret_dltensor(TVMFFIAny result): diff --git a/python/tvm/ffi/registry.py b/python/tvm/ffi/registry.py index 58df08d90c56..9302b251733b 100644 --- a/python/tvm/ffi/registry.py +++ b/python/tvm/ffi/registry.py @@ -50,6 +50,7 @@ def register(cls): if _SKIP_UNKNOWN_OBJECTS: return cls raise ValueError("Cannot find object type index for %s" % object_name) + core._add_class_attrs_by_reflection(type_index, cls) core._register_object_by_index(type_index, cls) return cls diff --git a/python/tvm/ffi/testing.py b/python/tvm/ffi/testing.py new file mode 100644 index 000000000000..843a10c896a8 --- /dev/null +++ b/python/tvm/ffi/testing.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities.""" + +from . import _ffi_api +from .core import Object +from .registry import register_object + + +@register_object("testing.TestObjectBase") +class TestObjectBase(Object): + """ + Test object base class. + """ + + +@register_object("testing.TestObjectDerived") +class TestObjectDerived(TestObjectBase): + """ + Test object derived class. + """ + + +def create_object(type_key: str, **kwargs) -> Object: + """ + Make an object by reflection. + + Parameters + ---------- + type_key : str + The type key of the object. + kwargs : dict + The keyword arguments to the object. + + Returns + ------- + obj : object + The created object. + + Note + ---- + This function is only used for testing purposes and should + not be used in other cases. + """ + args = [type_key] + for k, v in kwargs.items(): + args.append(k) + args.append(v) + return _ffi_api.MakeObjectFromPackedArgs(*args) diff --git a/tests/python/ffi/test_object.py b/tests/python/ffi/test_object.py new file mode 100644 index 000000000000..d333cbca089c --- /dev/null +++ b/tests/python/ffi/test_object.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm import ffi as tvm_ffi + + +def test_make_object(): + # with default values + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase") + assert obj0.v_i64 == 10 + assert obj0.v_f64 == 10.0 + assert obj0.v_str == "hello" + + +def test_method(): + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) + assert obj0.add_i64(1) == 13 + assert type(obj0).add_i64.__doc__ == "add_i64 method" + assert type(obj0).v_i64.__doc__ == "i64 field" + + +def test_setter(): + # test setter + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello") + assert obj0.v_i64 == 10 + obj0.v_i64 = 11 + assert obj0.v_i64 == 11 + obj0.v_str = "world" + assert obj0.v_str == "world" + + with pytest.raises(TypeError): + obj0.v_str = 1 + + with pytest.raises(TypeError): + obj0.v_i64 = "hello" + + +def test_derived_object(): + with pytest.raises(TypeError): + obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived") + + v_map = tvm_ffi.convert({"a": 1}) + v_array = tvm_ffi.convert([1, 2, 3]) + + obj0 = tvm_ffi.testing.create_object( + "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array + ) + assert obj0.v_map.same_as(v_map) + assert obj0.v_array.same_as(v_array) + assert obj0.v_i64 == 20 + assert obj0.v_f64 == 10.0 + assert obj0.v_str == "hello" + + obj0.v_i64 = 21 + assert obj0.v_i64 == 21