From 3268428516b17f57eff583c8c08c33782d2f6d1e Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 31 Aug 2025 16:23:22 -0400 Subject: [PATCH] [FFI][ABI] Introduce weak rc support This PR adds weak ref counter support to the FFI ABI. Weak rc is useful when we want to break cyclic dependencies. - When a strong rc goes to zero, we call the destructor of the object, but not freeing the memory - When both strong and weak rc goes to zero, we call the memory free operation The weak rc mechanism is useful when we want to break cyclic dependencies in object, where the weak rc can keep memory alive but the destructor is called. As of now, because we deliberately avoid cyles in codebase, we do not have strong use-case for weak rc. However, given weak rc is common practice in shared_ptr, Rust RC, and also used in torch's c10::intrusive_ptr. It is better to make sure the ABI is future compatible to such use-cases before we freeze. This PR implements weak rc as a u32 counter and strong rc as a u64 counter, with the following design consideration. - Weak rc is very rarely used and u32 is sufficient. - Keeping weak rc in u32 allows us to keep object header size to 24 bytes, saving extra 8 bytes(considering alignment) We also need to update deleter to take flags that consider both weak and strong deletion events. The implementation tries to optimize common case where both strong and weak goes to 0 at the same time and call deleter once with both flags set. --- ffi/include/tvm/ffi/c_api.h | 65 ++++- ffi/include/tvm/ffi/memory.h | 46 +-- ffi/include/tvm/ffi/object.h | 261 +++++++++++++++++- ffi/include/tvm/ffi/type_traits.h | 2 +- ffi/pyproject.toml | 2 +- ffi/python/tvm_ffi/cython/base.pxi | 2 +- ffi/python/tvm_ffi/cython/dtype.pxi | 2 +- ffi/python/tvm_ffi/cython/object.pxi | 2 +- ffi/src/ffi/object.cc | 8 +- ffi/tests/cpp/test_c_ffi_abi.cc | 2 +- ffi/tests/cpp/test_object.cc | 119 ++++++++ jvm/native/src/main/native/jni_helper_func.h | 2 +- .../native/org_apache_tvm_native_c_api.cc | 2 +- src/tir/transforms/make_packed_api.cc | 4 +- web/src/ctypes.ts | 6 +- web/src/runtime.ts | 8 +- 16 files changed, 475 insertions(+), 58 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index f099898b158d..b4f59526a900 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -156,6 +156,36 @@ typedef enum { /*! \brief Handle to Object from C API's pov */ typedef void* TVMFFIObjectHandle; +/*! + * \brief bitmask of the object deleter flag. + */ +#ifdef __cplusplus +enum TVMFFIObjectDeleterFlagBitMask : int32_t { +#else +typedef enum { +#endif + /*! + * \brief deleter action when strong reference count becomes zero. + * Need to call destructor of the object but not free the memory block. + */ + kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0, + /*! + * \brief deleter action when weak reference count becomes zero. + * Need to free the memory block. + */ + kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1, + /*! + * \brief deleter action when both strong and weak reference counts become zero. + * \note This is the most common case. + */ + kTVMFFIObjectDeleterFlagBitMaskBoth = + (kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak), +#ifdef __cplusplus +}; +#else +} TVMFFIObjectDeleterFlagBitMask; +#endif + /*! * \brief C-based type of all FFI object header that allocates on heap. * \note TVMFFIObject and TVMFFIAny share the common type_index header @@ -166,11 +196,22 @@ typedef struct TVMFFIObject { * \note The type index of Object and Any are shared in FFI. */ int32_t type_index; - /*! \brief Reference counter of the object. */ - int32_t ref_counter; + /*! + * \brief Weak reference counter of the object, for compatiblity with weak_ptr design. + * \note Use u32 to ensure that overall object stays within 24-byte boundary, usually + * manipulation of weak counter is less common than strong counter. + */ + uint32_t weak_ref_count; + /*! \brief Strong reference counter of the object. */ + uint64_t strong_ref_count; union { - /*! \brief Deleter to be invoked when reference counter goes to zero. */ - void (*deleter)(struct TVMFFIObject* self); + /*! + * \brief Deleter to be invoked when strong reference counter goes to zero. + * \param self The self object handle. + * \param flags The flags to indicate deletion behavior. + * \sa TVMFFIObjectDeleterFlagBitMask + */ + void (*deleter)(struct TVMFFIObject* self, int flags); /*! * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. * \note This helps us to ensure cross platform compatibility. @@ -307,13 +348,19 @@ typedef struct { // Section: Basic object API //------------------------------------------------------------ /*! - * \brief Free an object handle by decreasing reference + * \brief Increas the strong reference count of an object handle + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); + +/*! + * \brief Free an object handle by decreasing strong reference * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); +TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); /*! * \brief Convert type key to type index. @@ -470,7 +517,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \param dtype The DLDataType to convert. * \param out The output string. * \return 0 when success, nonzero when failure happens -* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. +* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index 02537df79cb4..533d0004274f 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -33,7 +33,7 @@ namespace tvm { namespace ffi { /*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(TVMFFIObject* obj); +typedef void (*FObjectDeleter)(TVMFFIObject* obj, int flags); /*! * \brief Allocate an object using default allocator. @@ -75,7 +75,8 @@ class ObjAllocatorBase { static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; + ffi_ptr->strong_ref_count = 1; + ffi_ptr->weak_ref_count = 1; ffi_ptr->type_index = T::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); @@ -96,7 +97,8 @@ class ObjAllocatorBase { ArrayType* ptr = Handler::New(static_cast(this), num_elems, std::forward(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; + ffi_ptr->strong_ref_count = 1; + ffi_ptr->weak_ref_count = 1; ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); @@ -136,14 +138,18 @@ class SimpleObjAllocator : public ObjAllocatorBase { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(TVMFFIObject* objptr) { + static void Deleter_(TVMFFIObject* objptr, int flags) { T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - delete reinterpret_cast(tptr); + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + delete reinterpret_cast(tptr); + } } }; @@ -182,15 +188,19 @@ class SimpleObjAllocator : public ObjAllocatorBase { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(TVMFFIObject* objptr) { + static void Deleter_(TVMFFIObject* objptr, int flags) { ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - StorageType* p = reinterpret_cast(tptr); - delete[] p; + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + StorageType* p = reinterpret_cast(tptr); + delete[] p; + } } }; }; diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index cf282a6e2744..cc5ee8d94585 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -143,7 +143,8 @@ class Object { public: Object() { - header_.ref_counter = 0; + header_.strong_ref_count = 0; + header_.weak_ref_count = 0; header_.deleter = nullptr; } /*! @@ -197,9 +198,9 @@ class Object { int32_t use_count() const { // only need relaxed load of counters #ifdef _MSC_VER - return (reinterpret_cast(&header_.ref_counter))[0]; // NOLINT(*) + return (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) #else - return __atomic_load_n(&(header_.ref_counter), __ATOMIC_RELAXED); + return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); #endif } @@ -230,33 +231,121 @@ class Object { static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } private: - /*! \brief increase reference count */ + /*! \brief increase strong reference count, the caller must already hold a strong reference */ void IncRef() { #ifdef _MSC_VER - _InterlockedIncrement(reinterpret_cast(&header_.ref_counter)); // NOLINT(*) + _InterlockedIncrement64( + reinterpret_cast(&header_.strong_ref_count)); // NOLINT(*) #else - __atomic_fetch_add(&(header_.ref_counter), 1, __ATOMIC_RELAXED); + __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED); +#endif + } + /*! + * \brief Try to lock the object to increase the strong reference count, + * the caller must already hold a strong reference. + * \return whether the lock call is successful and object is still alive. + */ + bool TryPromoteWeakPtr() { +#ifdef _MSC_VER + uint64_t old_count = + (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) + while (old_count > 0) { + uint64_t new_count = old_count + 1; + uint64_t old_count_loaded = _InterlockedCompareExchange64( + reinterpret_cast(&header_.strong_ref_count), new_count, old_count); + if (old_count == old_count_loaded) { + return true; + } + old_count = old_count_loaded; + } + return false; +#else + uint64_t old_count = __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); + while (old_count > 0) { + // must do CAS to ensure that we are the only one that increases the reference count + // avoid condition when two threads tries to promote weak to strong at same time + // or when strong deletion happens between the load and the CAS + uint64_t new_count = old_count + 1; + if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count, new_count, true, + __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) { + return true; + } + } + return false; +#endif + } + + /*! \brief increase weak reference count */ + void IncWeakRef() { +#ifdef _MSC_VER + _InterlockedIncrement(reinterpret_cast(&header_.weak_ref_count)); // NOLINT(*) +#else + __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED); #endif } - /*! \brief decrease reference count and delete the object */ + /*! \brief decrease strong reference count and delete the object */ void DecRef() { #ifdef _MSC_VER - if (_InterlockedDecrement( // - reinterpret_cast(&header_.ref_counter)) == 0) { // NOLINT(*) + // use simpler impl in windows to ensure correctness + if (_InterlockedDecrement64( // + reinterpret_cast(&header_.strong_ref_count)) == 0) { // NOLINT(*) // full barrrier is implicit in InterlockedDecrement if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + if (_InterlockedDecrement( // + reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } } } #else // first do a release, note we only need to acquire for deleter - if (__atomic_fetch_sub(&(header_.ref_counter), 1, __ATOMIC_RELEASE) == 1) { - // only acquire when we need to call deleter - // in this case we need to ensure all previous writes are visible + if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE) == 1) { + if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) { + // common case, we need to delete both the object and the memory block + // only acquire when we need to call deleter + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + // call deleter once + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); + } + } else { + // Slower path: there is still a weak reference left + __atomic_thread_fence(__ATOMIC_ACQUIRE); + // call destructor first, then decrease weak reference count + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } + } + } +#endif + } + + /*! \brief decrease weak reference count */ + void DecWeakRef() { +#ifdef _MSC_VER + if (_InterlockedDecrement( // + reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } +#else + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { __atomic_thread_fence(__ATOMIC_ACQUIRE); if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); } } #endif @@ -265,6 +354,8 @@ class Object { // friend classes template friend class ObjectPtr; + template + friend class WeakObjectPtr; friend struct tvm::ffi::details::ObjectUnsafe; }; @@ -402,6 +493,148 @@ class ObjectPtr { friend struct ObjectPtrHash; template friend class ObjectPtr; + template + friend class WeakObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief A custom smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class WeakObjectPtr { + public: + /*! \brief default constructor */ + WeakObjectPtr() {} + /*! \brief default constructor */ + WeakObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) {} + + /*! + * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.get()) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~WeakObjectPtr() { this->reset(); } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(WeakObjectPtr& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + /*! + * \brief copy assignment + * \param other The value to be assigned. + * \return reference to self. + */ + WeakObjectPtr& operator=(const WeakObjectPtr& other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + WeakObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignment + * \param other The value to be assigned. + * \return reference to self. + */ + WeakObjectPtr& operator=(WeakObjectPtr&& other) { // NOLINT(*) + // copy-and-swap idiom + WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + /*! \return The internal object pointer if the object is still alive, otherwise nullptr */ + ObjectPtr lock() const { + if (data_ != nullptr && data_->TryPromoteWeakPtr()) { + ObjectPtr ret; + // we already increase the reference count, so we don't need to do it again + ret.data_ = data_; + return ret; + } + return nullptr; + } + + /*! \brief reset the content of ptr to be nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecWeakRef(); + data_ = nullptr; + } + } + + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } + + /*! \return whether the pointer is nullptr */ + bool expired() const { return data_ == nullptr || data_->use_count() == 0; } + + private: + /*! \brief internal pointer field */ + Object* data_{nullptr}; + + /*! + * \brief constructor from Object + * \param data The data pointer + */ + explicit WeakObjectPtr(Object* data) : data_(data) { + if (data_ != nullptr) { + data_->IncWeakRef(); + } + } + + template + friend class WeakObjectPtr; friend struct tvm::ffi::details::ObjectUnsafe; }; diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index b019935a6cc8..9cdb2b933894 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -472,7 +472,7 @@ struct TypeTraits : public TypeTraitsBase { } else if (src->type_index == TypeIndex::kTVMFFINDArray) { // Conversion from NDArray pointer to DLTensor // based on the assumption that NDArray always follows the TVMFFIObject header - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 8 bytes"); + static_assert(sizeof(TVMFFIObject) == 24); return reinterpret_cast(reinterpret_cast(src->v_obj) + sizeof(TVMFFIObject)); } diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 8ed9e275e2b3..083a60fc3631 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a5" +version = "0.1.0a6" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index 14b3d97f5260..4a47efd773d9 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -171,7 +171,7 @@ cdef extern from "tvm/ffi/c_api.h": const TVMFFIMethodInfo* methods const TVMFFITypeMetadata* metadata - int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil + int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) nogil diff --git a/ffi/python/tvm_ffi/cython/dtype.pxi b/ffi/python/tvm_ffi/cython/dtype.pxi index 279b17f8c83c..d9e20b77f3a8 100644 --- a/ffi/python/tvm_ffi/cython/dtype.pxi +++ b/ffi/python/tvm_ffi/cython/dtype.pxi @@ -104,7 +104,7 @@ cdef class DataType: bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj) res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size)) - CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj)) + CHECK_CALL(TVMFFIObjectDecRef(temp_any.v_obj)) return res diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi index dad6bee51b34..1203f0c68289 100644 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ b/ffi/python/tvm_ffi/cython/object.pxi @@ -78,7 +78,7 @@ cdef class Object: def __dealloc__(self): if self.chandle != NULL: - CHECK_CALL(TVMFFIObjectFree(self.chandle)) + CHECK_CALL(TVMFFIObjectDecRef(self.chandle)) self.chandle = NULL def __ctypes_handle__(self): diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 61107cb63ff7..f96636fd4994 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -388,12 +388,18 @@ class TypeTable { } // namespace ffi } // namespace tvm -int TVMFFIObjectFree(TVMFFIObjectHandle handle) { +int TVMFFIObjectDecRef(TVMFFIObjectHandle handle) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); TVM_FFI_SAFE_CALL_END(); } +int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(handle); + TVM_FFI_SAFE_CALL_END(); +} + int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { TVM_FFI_SAFE_CALL_BEGIN(); out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/ffi/tests/cpp/test_c_ffi_abi.cc index 1efceef2971a..e6c6116edd8c 100644 --- a/ffi/tests/cpp/test_c_ffi_abi.cc +++ b/ffi/tests/cpp/test_c_ffi_abi.cc @@ -25,7 +25,7 @@ TEST(ABIHeaderAlignment, Default) { TVMFFIObject value; value.type_index = 10; EXPECT_EQ(reinterpret_cast(&value)->type_index, 10); - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 16 bytes"); + static_assert(sizeof(TVMFFIObject) == 24); } } // namespace diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index 4b53a70b42a2..f6bedcb6f371 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -103,4 +103,123 @@ TEST(Object, CAPIAccessor) { int32_t type_index = TVMFFIObjectGetTypeIndex(obj); EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex()); } + +TEST(Object, WeakObjectPtr) { + // Test basic construction from ObjectPtr + ObjectPtr strong_ptr = make_object(42); + WeakObjectPtr weak_ptr(strong_ptr); + + EXPECT_EQ(strong_ptr.use_count(), 1); + EXPECT_FALSE(weak_ptr.expired()); + EXPECT_EQ(weak_ptr.use_count(), 1); + + // Test lock() when object is still alive + ObjectPtr locked_ptr = weak_ptr.lock(); + EXPECT_TRUE(locked_ptr != nullptr); + EXPECT_EQ(locked_ptr->value, 42); + EXPECT_EQ(strong_ptr.use_count(), 2); + EXPECT_EQ(weak_ptr.use_count(), 2); + + // Test lock() when object is expired + strong_ptr.reset(); + locked_ptr.reset(); + EXPECT_TRUE(weak_ptr.expired()); + EXPECT_EQ(weak_ptr.use_count(), 0); + + ObjectPtr expired_lock = weak_ptr.lock(); + EXPECT_TRUE(expired_lock == nullptr); +} + +TEST(Object, WeakObjectPtrAssignment) { + // Test copy construction + ObjectPtr new_strong = make_object(100); + WeakObjectPtr weak1(new_strong); + WeakObjectPtr weak2(weak1); + + EXPECT_EQ(new_strong.use_count(), 1); + EXPECT_FALSE(weak1.expired()); + EXPECT_FALSE(weak2.expired()); + EXPECT_EQ(weak1.use_count(), 1); + EXPECT_EQ(weak2.use_count(), 1); + + // Test move construction + WeakObjectPtr weak3(std::move(weak1)); + EXPECT_TRUE(weak1.expired()); // weak1 should be moved from + EXPECT_FALSE(weak3.expired()); + EXPECT_EQ(weak3.use_count(), 1); + + // Test assignment + WeakObjectPtr weak4; + weak4 = weak2; + EXPECT_FALSE(weak2.expired()); + EXPECT_FALSE(weak4.expired()); + EXPECT_EQ(weak2.use_count(), 1); + EXPECT_EQ(weak4.use_count(), 1); + + // Test move assignment + WeakObjectPtr weak5; + weak5 = std::move(weak2); + EXPECT_TRUE(weak2.expired()); // weak2 should be moved from + EXPECT_FALSE(weak5.expired()); + EXPECT_EQ(weak5.use_count(), 1); + + // Test reset() + weak3.reset(); + EXPECT_TRUE(weak3.expired()); + EXPECT_EQ(weak3.use_count(), 0); + + // Test swap() + ObjectPtr strong_a = make_object(200); + ObjectPtr strong_b = make_object(300); + WeakObjectPtr weak_a(strong_a); + WeakObjectPtr weak_b(strong_b); + + weak_a.swap(weak_b); + EXPECT_EQ(weak_a.lock()->value, 300); + EXPECT_EQ(weak_b.lock()->value, 200); + + // Test construction from nullptr + WeakObjectPtr null_weak(nullptr); + EXPECT_TRUE(null_weak.expired()); + EXPECT_EQ(null_weak.use_count(), 0); + EXPECT_TRUE(null_weak.lock() == nullptr); + + // Test inheritance compatibility + ObjectPtr number_ptr = make_object(500); + WeakObjectPtr number_weak(number_ptr); + + EXPECT_FALSE(number_weak.expired()); + EXPECT_EQ(number_weak.use_count(), 1); + + // Test that weak references don't prevent object deletion + ObjectPtr temp_strong = make_object(999); + WeakObjectPtr temp_weak(temp_strong); + + EXPECT_FALSE(temp_weak.expired()); + temp_strong.reset(); + EXPECT_TRUE(temp_weak.expired()); + EXPECT_TRUE(temp_weak.lock() == nullptr); + + // Test multiple weak references + ObjectPtr multi_strong = make_object(777); + WeakObjectPtr multi_weak1(multi_strong); + WeakObjectPtr multi_weak2(multi_strong); + WeakObjectPtr multi_weak3(multi_strong); + + EXPECT_EQ(multi_strong.use_count(), 1); + EXPECT_FALSE(multi_weak1.expired()); + EXPECT_FALSE(multi_weak2.expired()); + EXPECT_FALSE(multi_weak3.expired()); + + // All weak references should be able to lock + ObjectPtr lock1 = multi_weak1.lock(); + ObjectPtr lock2 = multi_weak2.lock(); + ObjectPtr lock3 = multi_weak3.lock(); + + EXPECT_EQ(multi_strong.use_count(), 4); + EXPECT_EQ(lock1->value, 777); + EXPECT_EQ(lock2->value, 777); + EXPECT_EQ(lock3->value, 777); +} + } // namespace diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 5db3e279cf3f..9b50fb6a4914 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -236,7 +236,7 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { } case TypeIndex::kTVMFFIBytes: { jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); - TVMFFIObjectFree(value.v_obj); + TVMFFIObjectDecRef(value.v_obj); return ret; } default: { diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 3ebe7fddfa8f..b512ec8775bd 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -322,7 +322,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEn // Module JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj, jlong jhandle) { - return TVMFFIObjectFree(reinterpret_cast(jhandle)); + return TVMFFIObjectDecRef(reinterpret_cast(jhandle)); } // NDArray diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7477fe86363d..e6c6e9aa0275 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -299,10 +299,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { tvm::tir::StringImm(msg.str()), nop)); // if type_index is NDArray, we need to add the offset of the DLTensor header // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor* + const int64_t object_cell_offset = sizeof(TVMFFIObject); + static_assert(object_cell_offset == 24); arg_value = f_load_arg_value(param.dtype(), i); PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), - {arg_value, IntImm(DataType::Int(32), 16)}); + {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); } else if (dtype.is_bool()) { diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index d2ecf4b944b0..9836fbfda530 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -41,7 +41,7 @@ export const enum SizeOf { TVMFFIAny = 8 * 2, DLDataType = I32, DLDevice = I32 + I32, - ObjectHeader = 8 * 2, + ObjectHeader = 8 * 3, } //---------------The new TVM FFI--------------- @@ -142,9 +142,9 @@ export type FTVMFFIWasmFunctionCreate = ( export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void; /** - * int TVMFFIObjectFree(TVMFFIObjectHandle obj); + * int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); */ -export type FTVMFFIObjectFree = (obj: Pointer) => number; +export type FTVMFFIObjectDecRef = (obj: Pointer) => number; /** * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 071b2eed68e4..3720b1873eee 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -450,7 +450,7 @@ export class TVMObject implements Disposable { dispose(): void { if (this.handle != 0) { this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(this.handle) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(this.handle) ); this.handle = 0; } @@ -2253,7 +2253,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2264,7 +2264,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2275,7 +2275,7 @@ export class Instance implements Disposable { const bytesObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(bytesObjPtr) ); return result; }