diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h new file mode 100644 index 000000000000..7953ac47c1cf --- /dev/null +++ b/include/tvm/runtime/container/variant.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/variant.h + * \brief Runtime Variant container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_ +#define TVM_RUNTIME_CONTAINER_VARIANT_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +template +constexpr bool parent_is_base_of_any = false; + +template +constexpr bool parent_is_base_of_any> = + ((std::is_base_of_v && !std::is_same_v) || ...); + +/* \brief Utility to check if any parent is a base class of any child + * + * The type-checking in Variant relies on all types being from + * independent types, such that `Object::IsInstance` is sufficient to + * determine which variant is populated. + * + * For example, suppose the illegal `Variant` + * were allowed (e.g. to represent either the defintion of a variable + * or the usage of a variable). If a function returned + * `tir::PrimExpr`, it could result in either variant being filled, as + * the underlying type at runtime could be a `tir::Var`. This + * behavior is different from `std::variant`, which determines the + * active variant based solely on the compile-time type, and could + * produce very unexpected results if the variants have different + * semantic interpretations. + */ +template +static constexpr bool any_parent_is_base_of_any_child = false; + +template +static constexpr bool any_parent_is_base_of_any_child, ChildTuple> = + (parent_is_base_of_any || ...); +} // namespace detail + +template +class Variant : public ObjectRef { + static constexpr bool all_inherit_from_objectref = (std::is_base_of_v && ...); + static_assert(all_inherit_from_objectref, + "All types used in Variant<...> must inherit from ObjectRef"); + + static constexpr bool a_variant_inherits_from_another_variant = + detail::any_parent_is_base_of_any_child, std::tuple>; + static_assert(!a_variant_inherits_from_another_variant, + "Due to implementation limitations, " + "no type stored in a tvm::runtime::Variant " + "may be a subclass of any other type " + "stored in the same variant."); + + public: + /* \brief Helper utility to check if the type is part of the variant */ + template + static constexpr bool is_variant = (std::is_same_v || ...); + + /* \brief Helper utility for SFINAE if the type is part of the variant */ + template + using enable_if_variant = std::enable_if_t>; + + template > + Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*) + + template > + Variant& operator=(T value) { + ObjectRef::operator=(std::move(value)); + return *this; + } + + // These functions would normally be declared with the + // TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional + // type-checking inside the ObjectPtr constructor. + using ContainerType = Object; + Variant() : ObjectRef() {} + explicit Variant(ObjectPtr node) : ObjectRef(node) { + CHECK(node == nullptr || (node->IsInstance() || ...)) + << "Variant<" + << static_cast( + (std::stringstream() << ... << V::ContainerType::_type_key)) + .str() + << "> cannot hold an object of type " << node->GetTypeKey(); + } + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant); +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Variant; + +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_VARIANT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7aa8ef1ba7ff..caaaec364068 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -680,9 +681,6 @@ class TVMArgValue : public TVMPODValue_ { } else if (type_code_ == kTVMStr) { return std::string(value_.v_str); } else { - ICHECK(IsObjectRef()) - << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_) - << " to a string."; return AsObjectRef().operator std::string(); } } @@ -2063,6 +2061,56 @@ struct PackedFuncValueConverter> { } }; +template +struct PackedFuncValueConverter> { + using VType = Variant; + + // Can't just take `const TVMPODValue&` as an argument, because + // `TVMArgValue` and `TVMRetValue` have different implementations + // for `operator std::string()`. + template + static VType From(const PODSubclass& val) { + if (auto opt = TryAsObjectRef(val)) { + return opt.value(); + } + + if (auto opt = TryValueConverter(val)) { + return opt.value(); + } + + LOG(FATAL) << "Expected one of " + << static_cast( + (std::stringstream() << ... << VariantTypes::ContainerType::_type_key)) + .str() + << " but got " << ArgTypeCode2Str(val.type_code()); + } + + template + static Optional TryAsObjectRef(const TVMPODValue_& val) { + if (val.IsObjectRef()) { + return VType(val.AsObjectRef()); + } else if constexpr (sizeof...(VarRest)) { + return TryAsObjectRef(val); + } else { + return NullOpt; + } + } + + template + static Optional TryValueConverter(const PODSubclass& val) { + try { + return VType(PackedFuncValueConverter::From(val)); + } catch (const InternalError&) { + } + + if constexpr (sizeof...(VarRest)) { + return TryValueConverter(val); + } else { + return NullOpt; + } + } +}; + inline bool String::CanConvertFrom(const TVMArgValue& val) { return val.type_code() == kTVMStr || val.IsObjectRef(); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 6e7dec4cb776..690794fa61d1 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -165,4 +166,15 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::this_thread::sleep_for(duration); }); +TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { + if (x % 2 == 0) { + return IntImm(DataType::Int(64), x / 2); + } else { + return String("argument was odd"); + } +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsVariant") + .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); + } // namespace tvm diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index d75a510d0c95..5c9af19f9bc9 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -853,3 +854,49 @@ TEST(Optional, PackedCall) { test_ffi(s, static_cast(kTVMObjectHandle)); test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } + +TEST(Variant, Construct) { + Variant variant; + variant = PrimExpr(1); + ICHECK(variant.as()); + ICHECK(!variant.as()); + + variant = String("hello"); + ICHECK(variant.as()); + ICHECK(!variant.as()); +} + +TEST(Variant, InvalidTypeThrowsError) { + auto expected_to_throw = []() { + ObjectPtr node = make_object(); + Variant variant(node); + }; + + EXPECT_THROW(expected_to_throw(), InternalError); +} + +TEST(Variant, ReferenceIdentifyPreservedThroughAssignment) { + Variant variant; + ICHECK(!variant.defined()); + + String string_obj = "dummy_test"; + variant = string_obj; + ICHECK(variant.defined()); + ICHECK(variant.same_as(string_obj)); + ICHECK(string_obj.same_as(variant)); + + String out_string_obj = Downcast(variant); + ICHECK(string_obj.same_as(out_string_obj)); +} + +TEST(Variant, ExtractValueFromAssignment) { + Variant variant = String("hello"); + ICHECK_EQ(variant.as().value(), "hello"); +} + +TEST(Variant, AssignmentFromVariant) { + Variant variant = String("hello"); + auto variant2 = variant; + ICHECK(variant2.as()); + ICHECK_EQ(variant2.as().value(), "hello"); +} diff --git a/tests/python/unittest/test_ir_container.py b/tests/python/unittest/test_ir_container.py index 1915849e1044..aa482dd65cd7 100644 --- a/tests/python/unittest/test_ir_container.py +++ b/tests/python/unittest/test_ir_container.py @@ -112,5 +112,31 @@ def test_ndarray_container(): assert isinstance(arr[0], tvm.nd.NDArray) +def test_return_variant_type(): + func = tvm.get_global_func("testing.ReturnsVariant") + res_even = func(42) + assert isinstance(res_even, tvm.tir.IntImm) + assert res_even == 21 + + res_odd = func(17) + assert isinstance(res_odd, tvm.runtime.String) + assert res_odd == "argument was odd" + + +def test_pass_variant_type(): + func = tvm.get_global_func("testing.AcceptsVariant") + + assert func("string arg") == "runtime.String" + assert func(17) == "IntImm" + + +def test_pass_incorrect_variant_type(): + func = tvm.get_global_func("testing.AcceptsVariant") + float_arg = tvm.tir.FloatImm("float32", 0.5) + + with pytest.raises(Exception): + func(float_arg) + + if __name__ == "__main__": tvm.testing.main()