diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 8274904037bc..4ec72d684659 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -425,6 +425,15 @@ struct AnyUnsafe : public ObjectUnsafe { } } + template + static TVM_FFI_INLINE T MoveFromAnyStorageAfterCheck(Any&& ref) { + if constexpr (!std::is_same_v) { + return TypeTraits::MoveFromAnyStorageAfterCheck(&(ref.data_)); + } else { + return std::move(ref); + } + } + static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) { return reinterpret_cast(ref.data_.v_obj); } diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index 18cc3ecb726f..eeb892eff65e 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -123,9 +123,9 @@ * This macro is used to clear the padding parts for hash and equality check * in 32bit platform. */ -#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ - if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \ - result->v_int64 = 0; \ +#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ + if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \ + (result)->v_int64 = 0; \ } namespace tvm { diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h index 51e130f37385..cfc5590f5404 100644 --- a/ffi/include/tvm/ffi/container/container_details.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -284,6 +284,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); +/*! + * \brief Check if all T are compatible with Any. + * + * \tparam T The type to check. + * \return True if T is compatible with Any, false otherwise. + */ +template +inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); /** * \brief Check if Any storage of Derived can always be directly used as Base. * diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index f134be833193..c2b068890058 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -34,15 +34,73 @@ namespace tvm { namespace ffi { +namespace details { +/*! + * \brief Base class for Variant. + * + * \tparam all_storage_object Whether all types are derived from ObjectRef. + */ +template +class VariantBase { + public: + TVM_FFI_INLINE bool same_as(const VariantBase& other) const { + return data_.same_as(other.data_); + } + + protected: + template + explicit VariantBase(T other) : data_(std::move(other)) {} + + TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } + + TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } + + TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } + + Any data_; +}; + +// Specialization for all object ref case, backed by ObjectRef. +template <> +class VariantBase : public ObjectRef { + protected: + template + explicit VariantBase(const T& other) : ObjectRef(other) {} + template + explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} + explicit VariantBase(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit VariantBase(Any other) + : ObjectRef(details::AnyUnsafe::MoveFromAnyStorageAfterCheck(std::move(other))) {} + + TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } + + TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } + + TVM_FFI_INLINE AnyView ToAnyView() const { + TVMFFIAny any_data; + if (data_ == nullptr) { + any_data.type_index = TypeIndex::kTVMFFINone; + any_data.v_int64 = 0; + } else { + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); + any_data.type_index = data_->type_index(); + any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); + } + return AnyView::CopyFromTVMFFIAny(any_data); + } +}; +} // namespace details /*! * \brief A typed variant container. * - * A Variant is backed by Any container, with strong checks during construction. + * When all values are ObjectRef, Variant is backed by ObjectRef, + * otherwise it is backed by Any. */ template -class Variant { +class Variant : public details::VariantBase> { public: + using TParent = details::VariantBase>; static_assert(details::all_storage_enabled_v, "All types used in Variant<...> must be compatible with Any"); /* @@ -54,31 +112,30 @@ class Variant { template using enable_if_variant_contains_t = std::enable_if_t>; - Variant(const Variant& other) : data_(other.data_) {} - Variant(Variant&& other) : data_(std::move(other.data_)) {} + Variant(const Variant& other) : TParent(other.data_) {} + Variant(Variant&& other) : TParent(std::move(other.data_)) {} TVM_FFI_INLINE Variant& operator=(const Variant& other) { - data_ = other.data_; + this->SetData(other.data_); return *this; } TVM_FFI_INLINE Variant& operator=(Variant&& other) { - data_ = std::move(other.data_); + this->SetData(std::move(other.data_)); return *this; } template > - Variant(T other) : data_(std::move(other)) {} // NOLINT(*) + Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) template > TVM_FFI_INLINE Variant& operator=(T other) { - data_ = std::move(other); - return *this; + return operator=(Variant(std::move(other))); } template > TVM_FFI_INLINE std::optional as() const { - return data_.as(); + return this->TParent::ToAnyView().template as(); } /* @@ -89,29 +146,27 @@ class Variant { */ template >> TVM_FFI_INLINE const T* as() const { - return data_.as().value_or(nullptr); + return this->TParent::ToAnyView().template as().value_or(nullptr); } template > TVM_FFI_INLINE T get() const& { - return data_.template cast(); + return this->TParent::ToAnyView().template cast(); } template > TVM_FFI_INLINE T get() && { - return std::move(data_).template cast(); + return std::move(*this).TParent::MoveToAny().template cast(); } - TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); } + TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } private: friend struct TypeTraits>; friend struct ObjectPtrHash; friend struct ObjectPtrEqual; // constructor from any - explicit Variant(Any data) : data_(std::move(data)) {} - // internal data is backed by Any - Any data_; + explicit Variant(Any data) : TParent(std::move(data)) {} /*! * \brief Get the object pointer from the variant * \note This function is only available if all types used in Variant<...> are derived from @@ -122,8 +177,11 @@ class Variant { static_assert(all_object_v, "All types used in Variant<...> must be derived from ObjectRef " "to enable ObjectPtrHash/ObjectPtrEqual"); - return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_); + return this->data_.get(); } + // rexpose to friend class + using TParent::MoveToAny; + using TParent::ToAnyView; }; template @@ -132,11 +190,11 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public TypeTraitsBase { static TVM_FFI_INLINE void CopyToAnyView(const Variant& src, TVMFFIAny* result) { - *result = AnyView(src.data_).CopyToTVMFFIAny(); + *result = src.ToAnyView().CopyToTVMFFIAny(); } static TVM_FFI_INLINE void MoveToAny(Variant src, TVMFFIAny* result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_)); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); } static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index 816ae28e0e9c..d84cc64ae4b2 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -337,6 +337,7 @@ TEST(Any, ObjectMove) { auto v0 = std::move(any1).cast(); EXPECT_EQ(v0->value, 3.14); EXPECT_EQ(v0.use_count(), 1); + EXPECT_TRUE(any1 == nullptr); } } // namespace diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc index bd0b58b7c46e..b7c977fd344c 100644 --- a/ffi/tests/cpp/test_map.cc +++ b/ffi/tests/cpp/test_map.cc @@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) { ::tvm::ffi::Error); } -TEST(Map, ffi::FunctionGetItem) { +TEST(Map, FunctionGetItem) { Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, "map_get_item"); Map map{{"x", 1}, {"y", 2}}; diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index ee49ac75d15f..17a112908722 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -134,4 +134,31 @@ TEST(Variant, Upcast) { EXPECT_EQ(a1[0].get(), 1); } +TEST(Variant, AllObjectRef) { + Variant> v0 = TInt(1); + EXPECT_EQ(v0.get()->value, 1); + static_assert(std::is_base_of_v); + Any any0 = v0; + EXPECT_EQ(any0.cast()->value, 1); + auto v2 = any0.cast>>(); + EXPECT_TRUE(v0.same_as(v2)); + // assignment operator + v0 = Array({TInt(2), TInt(3)}); + EXPECT_EQ(v0.get>().size(), 2); + EXPECT_EQ(v0.get>()[0]->value, 2); + EXPECT_EQ(v0.get>()[1]->value, 3); + EXPECT_EQ(sizeof(v0), sizeof(ObjectRef)); +} + +TEST(Variant, PODSameAs) { + Variant v0 = 1; + Variant v1 = 1; + EXPECT_TRUE(v0.same_as(v1)); + String s = String("hello"); + v0 = s; + v1 = s; + EXPECT_TRUE(v0.same_as(v1)); + v1 = String("hello"); + EXPECT_TRUE(!v0.same_as(v1)); +} } // namespace