From 422eafdd961e23718afaf5541aeeb389ce453e7c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Sep 2023 03:28:04 +0000 Subject: [PATCH 1/3] [IR] Implemented Variant<...> container This commit introduces a new container, `Variant`, which is analogous to the `std::variant` introduced in C++17, the `enum` in Rust, or a tagged union in C. The `Variant` class is templated over the types that it may contain (e.g. `Variant`), where each type is a distinct option that can be stored within the container. `Variant` is implemented as a subclass of `ObjectRef` with no additional data members, similar to the implementation of `Optional`. It can be constructed from any of its contained types, and the contents can be inspected using the usual `my_object.as()` and `Downcast(my_object)` methods. This is intended to allow for drop-in replacement of `ObjectRef` with `Variant` in places that previously used a common base class. To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a `static_assert` explaining the limitation. This condition is necessary to mimic the semantics of `std::variant`, whose active member depends on the compile-time type of an object. Without this condition, the expression `Variant variant = PrimExpr(...)` could populate either of the variants depending on the run-time type of an object. Because the `Variant` class is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility. There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around various strategies. (This PR does not alter any existing implementations, instead introducing the `Variant` container that can be used in subsequent PRs, if desired.) * Workaround: Store a common base class. For example, the type of `relax::TensorStructInfoNode::shape` is `Optional`, with a comment stating that it should be only `NullOpt`, `ShapeExpr`, or `Var`. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as as `Optional>`, these errors could be automatically caught. * Workaround: Use additional data structures. For example, a `PrimFunc` parameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to a `DLTensor*` argument and appropriate unpacking code. However, these two types are represented as an `Array` and a `Map`, which together represent a `Array>`. The separate data structures must be kept in sync whenever modified, such as when removing a parameter. * Workaround: Use `std::variant`. For example, the `tvm::tir::IdentifyMemCpyImpl` utility function returns a `std::variant` with the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI. --- include/tvm/runtime/container/variant.h | 123 +++++++++++++++++++++ include/tvm/runtime/packed_func.h | 54 ++++++++- src/support/ffi_testing.cc | 12 ++ tests/cpp/container_test.cc | 21 ++++ tests/python/unittest/test_ir_container.py | 26 +++++ 5 files changed, 233 insertions(+), 3 deletions(-) create mode 100644 include/tvm/runtime/container/variant.h diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h new file mode 100644 index 000000000000..752575d00622 --- /dev/null +++ b/include/tvm/runtime/container/variant.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/variant.h + * \brief Runtime Variant container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_ +#define TVM_RUNTIME_CONTAINER_VARIANT_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +template +constexpr bool parent_is_base_of_any = false; + +template +constexpr bool parent_is_base_of_any> = + ((std::is_base_of_v && !std::is_same_v) || ...); + +/* \brief Utility to check if any parent is a base class of any child + * + * The type-checking in Variant relies on all types being from + * independent types, such that `Object::IsInstance` is sufficient to + * determine which variant is populated. + * + * For example, suppose the illegal `Variant` + * were allowed (e.g. to represent either the defintion of a variable + * or the usage of a variable). If a function returned + * `tir::PrimExpr`, it could result in either variant being filled, as + * the underlying type at runtime could be a `tir::Var`. This + * behavior is different from `std::variant`, which determines the + * active variant based solely on the compile-time type, and could + * produce very unexpected results if the variants have different + * semantic interpretations. + */ +template +constexpr bool any_parent_is_base_of_any_child = false; + +template +constexpr bool any_parent_is_base_of_any_child, ChildTuple> = + (parent_is_base_of_any || ...); +} // namespace detail + +template +class Variant : public ObjectRef { + static constexpr bool all_inherit_from_objectref = (std::is_base_of_v && ...); + static_assert(all_inherit_from_objectref, + "All types used in Variant<...> must inherit from ObjectRef"); + + static constexpr bool a_variant_inherits_from_another_variant = + detail::any_parent_is_base_of_any_child, std::tuple>; + static_assert(!a_variant_inherits_from_another_variant, + "Due to implementation limitations, " + "no type stored in a tvm::runtime::Variant " + "may be a subclass of any other type " + "stored in the same variant."); + + public: + /* \brief Helper utility to check if the type is part of the variant */ + template + static constexpr bool is_variant = (std::is_same_v || ...); + + /* \brief Helper utility for SFINAE if the type is part of the variant */ + template + using enable_if_variant = std::enable_if_t>; + + template > + explicit Variant(T value) : ObjectRef(std::move(value)) {} + + template > + Variant& operator=(T value) { + ObjectRef::operator=(std::move(value)); + return *this; + } + + // These functions would normally be declared with the + // TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional + // type-checking inside the ObjectPtr constructor. + using ContainerType = Object; + Variant() : ObjectRef() {} + explicit Variant(ObjectPtr node) : ObjectRef(node) { + CHECK(node == nullptr || (node->IsInstance() || ...)) + << "Variant<" + << static_cast( + (std::stringstream() << ... << V::ContainerType::_type_key)) + .str() + << "> cannot hold an object of type " << node->GetTypeKey(); + } + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant); +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Variant; + +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_VARIANT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7aa8ef1ba7ff..caaaec364068 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -680,9 +681,6 @@ class TVMArgValue : public TVMPODValue_ { } else if (type_code_ == kTVMStr) { return std::string(value_.v_str); } else { - ICHECK(IsObjectRef()) - << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_) - << " to a string."; return AsObjectRef().operator std::string(); } } @@ -2063,6 +2061,56 @@ struct PackedFuncValueConverter> { } }; +template +struct PackedFuncValueConverter> { + using VType = Variant; + + // Can't just take `const TVMPODValue&` as an argument, because + // `TVMArgValue` and `TVMRetValue` have different implementations + // for `operator std::string()`. + template + static VType From(const PODSubclass& val) { + if (auto opt = TryAsObjectRef(val)) { + return opt.value(); + } + + if (auto opt = TryValueConverter(val)) { + return opt.value(); + } + + LOG(FATAL) << "Expected one of " + << static_cast( + (std::stringstream() << ... << VariantTypes::ContainerType::_type_key)) + .str() + << " but got " << ArgTypeCode2Str(val.type_code()); + } + + template + static Optional TryAsObjectRef(const TVMPODValue_& val) { + if (val.IsObjectRef()) { + return VType(val.AsObjectRef()); + } else if constexpr (sizeof...(VarRest)) { + return TryAsObjectRef(val); + } else { + return NullOpt; + } + } + + template + static Optional TryValueConverter(const PODSubclass& val) { + try { + return VType(PackedFuncValueConverter::From(val)); + } catch (const InternalError&) { + } + + if constexpr (sizeof...(VarRest)) { + return TryValueConverter(val); + } else { + return NullOpt; + } + } +}; + inline bool String::CanConvertFrom(const TVMArgValue& val) { return val.type_code() == kTVMStr || val.IsObjectRef(); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 6e7dec4cb776..690794fa61d1 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -165,4 +166,15 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::this_thread::sleep_for(duration); }); +TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { + if (x % 2 == 0) { + return IntImm(DataType::Int(64), x / 2); + } else { + return String("argument was odd"); + } +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsVariant") + .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); + } // namespace tvm diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index d75a510d0c95..217c7fc2d4b3 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -853,3 +854,23 @@ TEST(Optional, PackedCall) { test_ffi(s, static_cast(kTVMObjectHandle)); test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } + +TEST(Variant, Construct) { + Variant variant; + variant = PrimExpr(1); + ICHECK(variant.as()); + ICHECK(!variant.as()); + + variant = String("hello"); + ICHECK(variant.as()); + ICHECK(!variant.as()); +} + +TEST(Variant, InvalidTypeThrowsError) { + auto expected_to_throw = []() { + ObjectPtr node = make_object(); + Variant variant(node); + }; + + EXPECT_THROW(expected_to_throw(), InternalError); +} diff --git a/tests/python/unittest/test_ir_container.py b/tests/python/unittest/test_ir_container.py index 1915849e1044..aa482dd65cd7 100644 --- a/tests/python/unittest/test_ir_container.py +++ b/tests/python/unittest/test_ir_container.py @@ -112,5 +112,31 @@ def test_ndarray_container(): assert isinstance(arr[0], tvm.nd.NDArray) +def test_return_variant_type(): + func = tvm.get_global_func("testing.ReturnsVariant") + res_even = func(42) + assert isinstance(res_even, tvm.tir.IntImm) + assert res_even == 21 + + res_odd = func(17) + assert isinstance(res_odd, tvm.runtime.String) + assert res_odd == "argument was odd" + + +def test_pass_variant_type(): + func = tvm.get_global_func("testing.AcceptsVariant") + + assert func("string arg") == "runtime.String" + assert func(17) == "IntImm" + + +def test_pass_incorrect_variant_type(): + func = tvm.get_global_func("testing.AcceptsVariant") + float_arg = tvm.tir.FloatImm("float32", 0.5) + + with pytest.raises(Exception): + func(float_arg) + + if __name__ == "__main__": tvm.testing.main() From e04ee4b8b201edb7e2645b48fdaacc80bfcb64d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 6 Sep 2023 17:48:16 -0500 Subject: [PATCH 2/3] Avoid ODR, lint errors for conversion to Variant --- include/tvm/runtime/container/variant.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index 752575d00622..7953ac47c1cf 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -58,10 +58,10 @@ constexpr bool parent_is_base_of_any> = * semantic interpretations. */ template -constexpr bool any_parent_is_base_of_any_child = false; +static constexpr bool any_parent_is_base_of_any_child = false; template -constexpr bool any_parent_is_base_of_any_child, ChildTuple> = +static constexpr bool any_parent_is_base_of_any_child, ChildTuple> = (parent_is_base_of_any || ...); } // namespace detail @@ -89,7 +89,7 @@ class Variant : public ObjectRef { using enable_if_variant = std::enable_if_t>; template > - explicit Variant(T value) : ObjectRef(std::move(value)) {} + Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*) template > Variant& operator=(T value) { From eb465caecb9816c075f0a32fdce0cbe7a1f9790f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Sep 2023 12:15:14 -0500 Subject: [PATCH 3/3] Added more C++ functionality tests. --- tests/cpp/container_test.cc | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 217c7fc2d4b3..5c9af19f9bc9 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -874,3 +874,29 @@ TEST(Variant, InvalidTypeThrowsError) { EXPECT_THROW(expected_to_throw(), InternalError); } + +TEST(Variant, ReferenceIdentifyPreservedThroughAssignment) { + Variant variant; + ICHECK(!variant.defined()); + + String string_obj = "dummy_test"; + variant = string_obj; + ICHECK(variant.defined()); + ICHECK(variant.same_as(string_obj)); + ICHECK(string_obj.same_as(variant)); + + String out_string_obj = Downcast(variant); + ICHECK(string_obj.same_as(out_string_obj)); +} + +TEST(Variant, ExtractValueFromAssignment) { + Variant variant = String("hello"); + ICHECK_EQ(variant.as().value(), "hello"); +} + +TEST(Variant, AssignmentFromVariant) { + Variant variant = String("hello"); + auto variant2 = variant; + ICHECK(variant2.as()); + ICHECK_EQ(variant2.as().value(), "hello"); +}