From d50abb86ab5bfcbde517e368d1d10cbf992502b1 Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Mon, 18 Jul 2022 09:23:31 -0700 Subject: [PATCH 1/2] [TVMScript] TracedObject class that simplifies tracing ObjectPaths --- src/script/printer/traced_object.h | 410 +++++++++++++++++++++++++++++ tests/cpp/traced_object_test.cc | 268 +++++++++++++++++++ 2 files changed, 678 insertions(+) create mode 100644 src/script/printer/traced_object.h create mode 100644 tests/cpp/traced_object_test.cc diff --git a/src/script/printer/traced_object.h b/src/script/printer/traced_object.h new file mode 100644 index 000000000000..aa07bf7284fc --- /dev/null +++ b/src/script/printer/traced_object.h @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/script/printer/traced_object.h + * Wrappers around TVM objects that also store an ObjectPath from some "root" object + * to the wrapper object. + */ + +#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ +#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { + +template +class TracedObject; +template +class TracedMap; +template +class TracedArray; +template +class TracedOptional; +template +class TracedBasicValue; + +namespace detail { + +template ::value> +struct TracedObjectWrapperSelector; + +template +struct TracedObjectWrapperSelector { + using Type = TracedBasicValue; +}; + +template +struct TracedObjectWrapperSelector { + using Type = TracedObject; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedMap; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedArray; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedOptional; +}; + +} // namespace detail + +/*! + * \brief Traced wrapper for regular (non-container) TVM objects. + */ +template +class TracedObject { + using ObjectType = typename RefT::ContainerType; + + public: + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedObject(const RefT& object_ref, ObjectPath path) + : ref_(object_ref), path_(std::move(path)) {} + + // Implicit conversion from a derived reference class + template + TracedObject(const TracedObject& derived) + : ref_(derived.Get()), path_(derived.GetPath()) {} + + /*! + * \brief Get a traced wrapper for an attribute of the wrapped object. + */ + template + typename detail::TracedObjectWrapperSelector::Type GetAttr(T BaseType::*member_ptr) const { + using WrapperType = typename detail::TracedObjectWrapperSelector::Type; + const ObjectType* node = static_cast(ref_.get()); + const T& attr = node->*member_ptr; + Optional attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr)); + return WrapperType(attr, path_->Attr(attr_key)); + } + + /*! + * \brief Access the wrapped object. + */ + const RefT& Get() const { return ref_; } + + /*! + * \brief Check if the reference to the wrapped object can be converted to `RefU`. + */ + template + bool IsInstance() const { + return ref_->template IsInstance(); + } + + /*! + * \brief Same as Get().defined(). + */ + bool defined() const { return ref_.defined(); } + + /*! + * \brief Convert the wrapped reference type to a subtype. + * + * Throws an exception if IsInstance() is false. + */ + template + TracedObject Downcast() const { + return TracedObject(tvm::runtime::Downcast(ref_), path_); + } + + /*! + * \brief Convert the wrapped reference type to a subtype. + * + * Returns an empty optional if IsInstance() is false. + */ + template + TracedOptional TryDowncast() const { + if (ref_->template IsInstance()) { + return Downcast(); + } else { + return TracedOptional(NullOpt, path_); + } + } + + /*! + * \brief Get the path of the wrapped object. + */ + const ObjectPath& GetPath() const { return path_; } + + private: + RefT ref_; + ObjectPath path_; +}; + +template +class TracedMapIterator { + public: + using WrappedV = typename detail::TracedObjectWrapperSelector::Type; + using MapIter = typename Map::iterator; + + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = ptrdiff_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + explicit TracedMapIterator(MapIter iter, ObjectPath map_path) + : iter_(iter), map_path_(std::move(map_path)) {} + + bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; } + + bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; } + + pointer operator->() const = delete; + + reference operator*() const { + auto kv = *iter_; + return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first))); + } + + TracedMapIterator& operator++() { + ++iter_; + return *this; + } + + TracedMapIterator operator++(int) { + TracedMapIterator copy = *this; + ++(*this); + return copy; + } + + private: + MapIter iter_; + ObjectPath map_path_; +}; + +/*! + * \brief Traced wrapper for Map objects. + */ +template +class TracedMap { + public: + using WrappedV = typename detail::TracedObjectWrapperSelector::Type; + + using iterator = TracedMapIterator; + + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedMap(Map map, ObjectPath path) + : map_(std::move(map)), path_(std::move(path)) {} + + WrappedV at(const K& key) const { + auto it = map_.find(key); + ICHECK(it != map_.end()) << "No such key in Map"; + auto kv = *it; + return WrappedV(kv.second, path_->MapValue(kv.first)); + } + + const Map& Get() const { return map_; } + + const ObjectPath& GetPath() const { return path_; } + + iterator begin() const { return iterator(map_.begin(), path_); } + + iterator end() const { return iterator(map_.end(), path_); } + + bool empty() const { return map_.empty(); } + + private: + Map map_; + ObjectPath path_; +}; + +template +class TracedArrayIterator { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + using difference_type = ptrdiff_t; + using value_type = WrappedT; + using pointer = WrappedT*; + using reference = WrappedT&; + using iterator_category = std::random_access_iterator_tag; + + explicit TracedArrayIterator(Array array, size_t index, ObjectPath array_path) + : array_(array), index_(index), array_path_(array_path) {} + + TracedArrayIterator& operator++() { + ++index_; + return *this; + } + TracedArrayIterator& operator--() { + --index_; + return *this; + } + TracedArrayIterator operator++(int) { + TracedArrayIterator copy = *this; + ++index_; + return copy; + } + TracedArrayIterator operator--(int) { + TracedArrayIterator copy = *this; + --index_; + return copy; + } + + TracedArrayIterator operator+(difference_type offset) const { + return TracedArrayIterator(array_, index_ + offset, array_path_); + } + + TracedArrayIterator operator-(difference_type offset) const { + return TracedArrayIterator(array_, index_ - offset, array_path_); + } + + difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; } + + bool operator==(TracedArrayIterator other) const { + return array_.get() == other.array_.get() && index_ == other.index_; + } + bool operator!=(TracedArrayIterator other) const { return !(*this == other); } + value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); } + + bool empty() const { return array_.empty(); } + + private: + Array array_; + size_t index_; + ObjectPath array_path_; +}; + +/*! + * \brief Traced wrapper for Array objects. + */ +template +class TracedArray { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + using iterator = TracedArrayIterator; + + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedArray(Array array, ObjectPath path) + : array_(std::move(array)), path_(std::move(path)) {} + + const Array& Get() const { return array_; } + + const ObjectPath& GetPath() const { return path_; } + + WrappedT operator[](size_t index) const { + return WrappedT(array_[index], path_->ArrayIndex(index)); + } + + iterator begin() const { return iterator(array_, 0, path_); } + + iterator end() const { return iterator(array_, array_.size(), path_); } + + bool empty() const { return array_.empty(); } + + size_t size() const { return array_.size(); } + + private: + Array array_; + ObjectPath path_; +}; + +/*! + * \brief Traced wrapper for Optional objects. + */ +template +class TracedOptional { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit) + : optional_(value.Get().defined() ? value.Get() : Optional(NullOpt)), + path_(value.GetPath()) {} + + explicit TracedOptional(Optional optional, ObjectPath path) + : optional_(std::move(optional)), path_(std::move(path)) {} + + const Optional& Get() const { return optional_; } + + const ObjectPath& GetPath() const { return path_; } + + bool defined() const { return optional_.defined(); } + + WrappedT value() const { return WrappedT(optional_.value(), path_); } + + explicit operator bool() const { return optional_.defined(); } + + private: + Optional optional_; + ObjectPath path_; +}; + +/*! + * \brief Traced wrapper for basic values (i.e. non-TVM objects) + */ +template +class TracedBasicValue { + public: + explicit TracedBasicValue(const T& value, ObjectPath path) + : value_(value), path_(std::move(path)) {} + + const T& Get() const { return value_; } + + const ObjectPath& GetPath() const { return path_; } + + /*! + * \brief Transform the wrapped value without changing its path. + */ + template + typename detail::TracedObjectWrapperSelector::type>::Type + ApplyFunc(F&& f) const { + return MakeTraced(f(value_), path_); + } + + private: + T value_; + ObjectPath path_; +}; + +/*! + * \brief Wrap the given root object in an appropriate traced wrapper class. + */ +template +typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object) { + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + return WrappedT(object, ObjectPath::Root()); +} + +/*! + * \brief Wrap the given object with the given path in an appropriate traced wrapper class. + */ +template +typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object, + ObjectPath path) { + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + return WrappedT(object, std::move(path)); +} + +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ diff --git a/tests/cpp/traced_object_test.cc b/tests/cpp/traced_object_test.cc new file mode 100644 index 000000000000..1e07919dc259 --- /dev/null +++ b/tests/cpp/traced_object_test.cc @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <../../src/script/printer/traced_object.h> +#include +#include +#include +#include + +using namespace tvm; + +namespace { + +class DummyObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "TracedObjectTestDummyObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(DummyObjectNode, Object); +}; + +class DummyObject : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DummyObject, ObjectRef, DummyObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(DummyObjectNode); + +class ObjectWithAttrsNode : public Object { + public: + int64_t int64_attr = 5; + Map map_attr; + Array array_attr; + DummyObject obj_attr; + + ObjectWithAttrsNode() : obj_attr(make_object()) {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("int64_attr", &int64_attr); + v->Visit("map_attr", &map_attr); + v->Visit("array_attr", &array_attr); + v->Visit("obj_attr", &obj_attr); + } + + static constexpr const char* _type_key = "TracedObjectTestObjectWithAttrs"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectWithAttrsNode, Object); +}; + +class ObjectWithAttrs : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectWithAttrs, ObjectRef, ObjectWithAttrsNode); +}; + +TVM_REGISTER_NODE_TYPE(ObjectWithAttrsNode); + +} // anonymous namespace + +TEST(TracedObjectTest, MakeTraced_RootObject) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + + static_assert(std::is_same>::value); + ICHECK(root_traced.GetPath()->PathsEqual(ObjectPath::Root())); + ICHECK_EQ(root_traced.Get().get(), root.get()); +} + +TEST(TracedObjectTest, MakeTraced_WithPath) { + ObjectWithAttrs obj(make_object()); + auto traced = MakeTraced(obj, ObjectPath::Root()->Attr("foo")); + + static_assert(std::is_same>::value); + ICHECK(traced.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); + ICHECK_EQ(traced.Get().get(), obj.get()); +} + +TEST(TracedObjectTest, TracedObject_ImplicitConversionFromDerived) { + DummyObject obj(make_object()); + auto traced = MakeTraced(obj); + static_assert(std::is_same>::value); + + // Check that TracedObject is implicitly converted to TracedObject + auto base_traced = [](const TracedObject& base) { return base; }(traced); + + static_assert(std::is_same>::value); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_ObjectRef) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + auto obj_attr = root_traced.GetAttr(&ObjectWithAttrsNode::obj_attr); + static_assert(std::is_same>::value); + ICHECK(obj_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("obj_attr"))); + ICHECK_EQ(obj_attr.Get().get(), root->obj_attr.get()); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Map) { + ObjectWithAttrs root(make_object()); + root->map_attr.Set("foo", "bar"); + + auto root_traced = MakeTraced(root); + auto map_attr = root_traced.GetAttr(&ObjectWithAttrsNode::map_attr); + static_assert(std::is_same>::value); + ICHECK(map_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr"))); + ICHECK_EQ(map_attr.Get().get(), root->map_attr.get()); + + auto map_val = map_attr.at("foo"); + ICHECK_EQ(map_val.Get(), "bar"); + ICHECK( + map_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr")->MapValue(String("foo")))); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Array) { + ObjectWithAttrs root(make_object()); + root->array_attr.push_back("foo"); + root->array_attr.push_back("bar"); + + auto root_traced = MakeTraced(root); + auto array_attr = root_traced.GetAttr(&ObjectWithAttrsNode::array_attr); + static_assert(std::is_same>::value); + ICHECK(array_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr"))); + ICHECK_EQ(array_attr.Get().get(), root->array_attr.get()); + + auto array_val = array_attr[1]; + ICHECK_EQ(array_val.Get(), "bar"); + ICHECK(array_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr")->ArrayIndex(1))); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Int64) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + + auto int64_attr = root_traced.GetAttr(&ObjectWithAttrsNode::int64_attr); + static_assert(std::is_same>::value); + ICHECK_EQ(int64_attr.Get(), 5); + ICHECK(int64_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("int64_attr"))); +} + +TEST(TracedObjectTest, TracedObject_IsInstance) { + ObjectRef dummy(make_object()); + auto traced = MakeTraced(dummy); + ICHECK(traced.IsInstance()); + ICHECK(!traced.IsInstance()); +} + +TEST(TracedObjectTest, TracedObject_Downcast) { + ObjectRef root(make_object()); + auto traced = MakeTraced(root); + + auto as_dummy = traced.Downcast(); + static_assert(std::is_same>::value); + ICHECK_EQ(as_dummy.Get(), root); + + // Try downcasting to a wrong type + bool caught = false; + try { + traced.Downcast(); + } catch (std::exception& e) { + caught = strstr(e.what(), + "Downcast from TracedObjectTestDummyObject to TracedObjectTestObjectWithAttrs " + "failed") != nullptr; + } + ICHECK(caught); +} + +TEST(TracedObjectTest, TracedObject_TryDowncast) { + ObjectRef root(make_object()); + auto traced = MakeTraced(root); + + auto as_dummy = traced.TryDowncast(); + static_assert(std::is_same>::value); + ICHECK(as_dummy.defined()); + ICHECK_EQ(as_dummy.value().Get(), root); + + // Try downcasting to a wrong type + ICHECK(!traced.TryDowncast().defined()); +} + +TEST(TracedObjectTest, TracedMap_At) { + Map m({{"k1", "foo"}, {"k2", "bar"}}); + auto traced = MakeTraced(m); + + auto traced_foo = traced.at("k1"); + static_assert(std::is_same>::value); + ICHECK_EQ(traced_foo.Get(), "foo"); + ICHECK(traced_foo.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); +} + +TEST(TracedObjectTest, TracedMap_Iterator) { + Map m({{"k1", "foo"}, {"k2", "bar"}}); + auto traced = MakeTraced(m); + + size_t k1_count = 0; + size_t k2_count = 0; + + for (const auto& kv : traced) { + if (kv.first == "k1") { + ++k1_count; + ICHECK_EQ(kv.second.Get(), "foo"); + ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); + } else if (kv.first == "k2") { + ++k2_count; + ICHECK_EQ(kv.second.Get(), "bar"); + ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k2")))); + } else { + ICHECK(false); + } + } + + ICHECK_EQ(k1_count, 1); + ICHECK_EQ(k2_count, 1); +} + +TEST(TracedObjectTest, TracedArray_Index) { + Array a = {"foo", "bar"}; + auto traced = MakeTraced(a); + + auto traced_bar = traced[1]; + static_assert(std::is_same>::value); + ICHECK_EQ(traced_bar.Get(), "bar"); + ICHECK(traced_bar.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); +} + +TEST(TracedObjectTest, TracedArray_Iterator) { + Array a = {"foo", "bar"}; + auto traced = MakeTraced(a); + + size_t index = 0; + for (const auto& x : traced) { + if (index == 0) { + ICHECK_EQ(x.Get(), "foo"); + ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(0))); + } else if (index == 1) { + ICHECK_EQ(x.Get(), "bar"); + ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); + } else { + ICHECK(false); + } + ++index; + } + + ICHECK_EQ(index, 2); +} + +TEST(TracedObjectTest, TracedBasicValue_ApplyFunc) { + auto traced = MakeTraced(123, ObjectPath::Root()->Attr("foo")); + static_assert(std::is_same>::value); + + auto transformed = traced.ApplyFunc([](int x) { return x + 4.0; }); + static_assert(std::is_same>::value); + + ICHECK(transformed.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); +} From 8ce161c20e48b3ab33b78a18569a63a549e52869 Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Thu, 4 Aug 2022 11:06:42 -0700 Subject: [PATCH 2/2] Move traced_object.h to public headers & add doc comments. --- .../tvm}/script/printer/traced_object.h | 76 ++++++++++++++++++- tests/cpp/traced_object_test.cc | 2 +- 2 files changed, 75 insertions(+), 3 deletions(-) rename {src => include/tvm}/script/printer/traced_object.h (86%) diff --git a/src/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h similarity index 86% rename from src/script/printer/traced_object.h rename to include/tvm/script/printer/traced_object.h index aa07bf7284fc..6f04b66cec97 100644 --- a/src/script/printer/traced_object.h +++ b/include/tvm/script/printer/traced_object.h @@ -159,6 +159,9 @@ class TracedObject { ObjectPath path_; }; +/*! + * \brief Iterator class for TracedMap + */ template class TracedMapIterator { public: @@ -215,6 +218,9 @@ class TracedMap { explicit TracedMap(Map map, ObjectPath path) : map_(std::move(map)), path_(std::move(path)) {} + /*! + * \brief Get a value by its key, wrapped in a traced wrapper. + */ WrappedV at(const K& key) const { auto it = map_.find(key); ICHECK(it != map_.end()) << "No such key in Map"; @@ -222,14 +228,29 @@ class TracedMap { return WrappedV(kv.second, path_->MapValue(kv.first)); } + /*! + * \brief Access the wrapped map object. + */ const Map& Get() const { return map_; } + /*! + * \brief Get the path of the wrapped object. + */ const ObjectPath& GetPath() const { return path_; } + /*! + * \brief Get an iterator to the first item of the map. + */ iterator begin() const { return iterator(map_.begin(), path_); } + /*! + * \brief Get an iterator to the end of the map. + */ iterator end() const { return iterator(map_.end(), path_); } + /*! + * \brief Returns true iff the wrapped map is empty. + */ bool empty() const { return map_.empty(); } private: @@ -237,6 +258,9 @@ class TracedMap { ObjectPath path_; }; +/*! + * \brief Iterator class for TracedArray + */ template class TracedArrayIterator { public: @@ -286,8 +310,6 @@ class TracedArrayIterator { bool operator!=(TracedArrayIterator other) const { return !(*this == other); } value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); } - bool empty() const { return array_.empty(); } - private: Array array_; size_t index_; @@ -308,20 +330,45 @@ class TracedArray { explicit TracedArray(Array array, ObjectPath path) : array_(std::move(array)), path_(std::move(path)) {} + /*! + * \brief Access the wrapped array object. + */ const Array& Get() const { return array_; } + /*! + * \brief Get the path of the wrapped array object. + */ const ObjectPath& GetPath() const { return path_; } + /*! + * \brief Get an element by index, wrapped in a traced wrapper. + */ WrappedT operator[](size_t index) const { return WrappedT(array_[index], path_->ArrayIndex(index)); } + /*! + * \brief Get an iterator to the first array element. + * + * The iterator's dereference operator will automatically wrap each element in a traced wrapper. + */ iterator begin() const { return iterator(array_, 0, path_); } + /*! + * \brief Get an iterator to the end of the array. + * + * The iterator's dereference operator will automatically wrap each element in a traced wrapper. + */ iterator end() const { return iterator(array_, array_.size(), path_); } + /*! + * \brief Returns true iff the wrapped array is empty. + */ bool empty() const { return array_.empty(); } + /*! + * \brief Get the size of the wrapped array. + */ size_t size() const { return array_.size(); } private: @@ -337,21 +384,40 @@ class TracedOptional { public: using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + /*! + * \brief Implicit conversion from the corresponding non-optional traced wrapper. + */ TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit) : optional_(value.Get().defined() ? value.Get() : Optional(NullOpt)), path_(value.GetPath()) {} + // Don't use this direcly. For convenience, call MakeTraced() instead. explicit TracedOptional(Optional optional, ObjectPath path) : optional_(std::move(optional)), path_(std::move(path)) {} + /*! + * \brief Access the wrapped optional object. + */ const Optional& Get() const { return optional_; } + /*! + * \brief Get the path of the wrapped optional object. + */ const ObjectPath& GetPath() const { return path_; } + /*! + * \brief Returns true iff the object is present. + */ bool defined() const { return optional_.defined(); } + /*! + * \brief Returns a non-optional traced wrapper, throws if defined() is false. + */ WrappedT value() const { return WrappedT(optional_.value(), path_); } + /*! + * \brief Same as defined(). + */ explicit operator bool() const { return optional_.defined(); } private: @@ -368,8 +434,14 @@ class TracedBasicValue { explicit TracedBasicValue(const T& value, ObjectPath path) : value_(value), path_(std::move(path)) {} + /*! + * \brief Access the wrapped value. + */ const T& Get() const { return value_; } + /*! + * \brief Get the path of the wrapped value. + */ const ObjectPath& GetPath() const { return path_; } /*! diff --git a/tests/cpp/traced_object_test.cc b/tests/cpp/traced_object_test.cc index 1e07919dc259..7890a67eef95 100644 --- a/tests/cpp/traced_object_test.cc +++ b/tests/cpp/traced_object_test.cc @@ -17,11 +17,11 @@ * under the License. */ -#include <../../src/script/printer/traced_object.h> #include #include #include #include +#include using namespace tvm;