From 9d6f8e4e1384a9c993e1b8914ad0fb4bc856547e Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Thu, 30 Jun 2022 13:36:54 -0700 Subject: [PATCH 1/5] [TVMScript] Add ObjectPath class --- include/tvm/node/object_path.h | 281 +++++++++++++++++++ python/tvm/runtime/object_path.py | 122 ++++++++ src/node/object_path.cc | 322 ++++++++++++++++++++++ tests/python/unittest/test_object_path.py | 157 +++++++++++ 4 files changed, 882 insertions(+) create mode 100644 include/tvm/node/object_path.h create mode 100644 python/tvm/runtime/object_path.py create mode 100644 src/node/object_path.cc create mode 100644 tests/python/unittest/test_object_path.py diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h new file mode 100644 index 000000000000..43f272c0f74b --- /dev/null +++ b/include/tvm/node/object_path.h @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/node/object_path.h + * ObjectPath class that represents a path from a root object to one of its descendants + * via attribute access, array indexing etc. + */ + +#ifndef TVM_NODE_OBJECT_PATH_H_ +#define TVM_NODE_OBJECT_PATH_H_ + +#include +#include + +#include + +namespace tvm { + +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; + +class ObjectPath; + +/*! + * \brief Path to an object from some root object. + * + * Motivation: + * + * Same IR node object can be referenced in several different contexts inside a larger IR object. + * For example, a variable could be referenced in several statements within a block. + * + * This makes it impossible to use an object pointer to uniquely identify a "location" within + * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem + * by serving as a unique "locator". + */ +class ObjectPathNode : public Object { + public: + /*! \brief Get the parent path */ + ObjectPath GetParent() const; + /*! + * \brief Get the length of the path. + * + * For example, the path returned by `ObjectPath::Root()` has length 1. + */ + size_t Length() const; + + /*! + * \brief Get a path prefix of the given length. + * + * Provided `length` must not exceed the `Length()` of this path. + */ + ObjectPath GetPrefix(size_t length) const; + + /*! + * \brief Check if this path is a prefix of another path. + * + * The prefix is not strict, i.e. a path is considered a prefix of itself. + */ + bool IsPrefixOf(const ObjectPath& other) const; + + /*! \brief Check if two paths are equal. */ + bool PathsEqual(const ObjectPath& other) const; + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(const char* attr_key); + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(String attr_key); + + /*! \brief Extend this path with access to an array element. */ + ObjectPath ArrayIndex(size_t index); + + /*! \brief Extend this path with access to a missing array element. */ + ObjectPath MissingArrayElement(size_t index); + + /*! \brief Extend this path with access to a map value. */ + ObjectPath MapValue(ObjectRef key); + + /*! \brief Extend this path with access to a missing map entry. */ + ObjectPath MissingMapEntry(); + + static constexpr const char* _type_key = "ObjectPath"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); + + protected: + explicit ObjectPathNode(ObjectPathNode* parent); + + friend class ObjectPath; + friend std::string GetObjectPathRepr(const ObjectPathNode* node); + + const ObjectPathNode* ParentNode() const; + + /*! Compares just the last node of the path, without comparing the whole path. */ + virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0; + + virtual std::string LastNodeString() const = 0; + + private: + ObjectRef parent_; + size_t length_; +}; + +class ObjectPath : public ObjectRef { + public: + /*! \brief Create a path that represents the root object itself. */ + static ObjectPath Root(); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); +}; + +//------------------------------------------------------------------------- +//----- Concrete object path nodes ------------------------------------ +//------------------------------------------------------------------------- + +// ----- Root ----- + +class RootPathNode final : public ObjectPathNode { + public: + explicit RootPathNode(); + + static constexpr const char* _type_key = "RootPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class RootPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode); +}; + +// ----- Attribute access ----- + +class AttributeAccessPathNode final : public ObjectPathNode { + public: + /*! \brief Name of the attribute being accessed. Must be a static string. */ + String attr_key; + + explicit AttributeAccessPathNode(ObjectPathNode* parent, String attr_key); + + static constexpr const char* _type_key = "AttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class AttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, AttributeAccessPathNode); +}; + +// ----- Unknown attribute access ----- + +class UnknownAttributeAccessPathNode final : public ObjectPathNode { + public: + explicit UnknownAttributeAccessPathNode(ObjectPathNode* parent); + + static constexpr const char* _type_key = "UnknownAttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class UnknownAttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath, + UnknownAttributeAccessPathNode); +}; + +// ----- Array element access by index ----- + +class ArrayIndexPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is being accessed. */ + size_t index; + + explicit ArrayIndexPathNode(ObjectPathNode* parent, size_t index); + + static constexpr const char* _type_key = "ArrayIndexPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class ArrayIndexPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode); +}; + +// ----- Missing array element ----- + +class MissingArrayElementPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is missing. */ + size_t index; + + explicit MissingArrayElementPathNode(ObjectPathNode* parent, size_t index); + + static constexpr const char* _type_key = "MissingArrayElementPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingArrayElementPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, MissingArrayElementPathNode); +}; + +// ----- Map value ----- + +class MapValuePathNode : public ObjectPathNode { + public: + /*! \brief Key of the map entry that is being accessed */ + ObjectRef key; + + explicit MapValuePathNode(ObjectPathNode* parent, ObjectRef key); + + static constexpr const char* _type_key = "MapValuePath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MapValuePath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode); +}; + +// ----- Missing map entry ----- + +class MissingMapEntryPathNode : public ObjectPathNode { + public: + explicit MissingMapEntryPathNode(ObjectPathNode* parent); + + static constexpr const char* _type_key = "MissingMapEntryPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingMapEntryPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, MissingMapEntryPathNode); +}; + +} // namespace tvm + +#endif // TVM_NODE_OBJECT_PATH_H_ diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py new file mode 100644 index 000000000000..938e14ab4c9e --- /dev/null +++ b/python/tvm/runtime/object_path.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm._ffi +from . import _ffi_node_api +from tvm.runtime import Object + + +__all__ = ( + "ObjectPath", + "RootPath", + "AttributeAccessPath", + "UnknownAttributeAccessPath", + "ArrayIndexPath", + "MissingArrayElementPath", + "MapValuePath", + "MissingMapEntryPath", +) + + +@tvm._ffi.register_object("ObjectPath") +class ObjectPath(Object): + def __init__(self) -> None: + super().__init__() + raise ValueError( + "ObjectPath can't be initialized directly. " + "Use ObjectPath.root() to create a path to the root object" + ) + + @staticmethod + def root() -> "ObjectPath": + return _ffi_node_api.ObjectPathRoot() + + def __eq__(self, other): + return _ffi_node_api.ObjectPathEqual(self, other) + + def __ne__(self, other): + return not _ffi_node_api.ObjectPathEqual(self, other) + + @property + def parent(self) -> "ObjectPath": + return _ffi_node_api.ObjectPathGetParent(self) + + def __len__(self) -> int: + return _ffi_node_api.ObjectPathLength(self) + + def __getitem__(self, slc) -> "ObjectPath": + if not isinstance(slc, slice): + raise TypeError("ObjectPath can only be indexed with a slice") + if slc.start is not None or slc.step is not None: + raise ValueError( + "ObjectPath can only be indexed with a slice" " of the form [:prefix_len]" + ) + prefix_len = slc.stop if slc.stop >= 0 else len(self) + slc.stop + return _ffi_node_api.ObjectPathGetPrefix(self, prefix_len) + + def is_prefix_of(self, other) -> "ObjectPath": + return _ffi_node_api.ObjectPathIsPrefixOf(self, other) + + def attr(self, attr_key) -> "ObjectPath": + return _ffi_node_api.ObjectPathAttr(self, attr_key) + + def array_index(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathArrayIndex(self, index) + + def missing_array_element(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingArrayElement(self, index) + + def map_value(self, key) -> "ObjectPath": + return _ffi_node_api.ObjectPathMapValue(self, tvm.runtime.convert(key)) + + def missing_map_entry(self) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingMapEntry(self) + + +@tvm._ffi.register_object("RootPath") +class RootPath(ObjectPath): + pass + + +@tvm._ffi.register_object("AttributeAccessPath") +class AttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("UnknownAttributeAccessPath") +class UnknownAttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("ArrayIndexPath") +class ArrayIndexPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingArrayElementPath") +class MissingArrayElementPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MapValuePath") +class MapValuePath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingMapEntryPath") +class MissingMapEntryPath(ObjectPath): + pass diff --git a/src/node/object_path.cc b/src/node/object_path.cc new file mode 100644 index 000000000000..4596f139c116 --- /dev/null +++ b/src/node/object_path.cc @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +using namespace tvm::runtime; + +namespace tvm { + +// ============== ObjectPathNode ============== + +ObjectPathNode::ObjectPathNode(ObjectPathNode* parent) + : parent_(GetRef(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} + +// --- GetParent --- + +ObjectPath ObjectPathNode::GetParent() const { return Downcast(parent_); } + +TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_typed([](const ObjectPath& path) { + return path->GetParent(); +}); + +// --- Length --- + +size_t ObjectPathNode::Length() const { return length_; } + +TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_typed([](const ObjectPath& path) { + return static_cast(path->Length()); +}); + +// --- GetPrefix --- + +ObjectPath ObjectPathNode::GetPrefix(size_t length) const { + if (length > Length()) { + throw std::out_of_range("Attempted to get a prefix longer than the path itself"); + } + + const ObjectPathNode* node = this; + size_t suffix_len = Length() - length; + for (size_t i = 0; i < suffix_len; ++i) { + node = node->ParentNode(); + } + + return GetRef(node); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix") + .set_body_typed([](const ObjectPath& path, int64_t length) { + if (length < 0) { + throw std::out_of_range("Prefix length can't be negative"); + } + return path->GetPrefix(static_cast(length)); + }); + +// --- IsPrefixOf --- + +bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { + if (!other.defined()) { + return false; + } + + size_t this_len = Length(); + if (this_len > other->Length()) { + return false; + } + return this->PathsEqual(other->GetPrefix(this_len)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf") + .set_body_typed([](const ObjectPath& a, const ObjectPath& b) { return a->IsPrefixOf(b); }); + +// --- Attr --- + +ObjectPath ObjectPathNode::Attr(const char* attr_key) { + if (attr_key != nullptr) { + return ObjectPath(make_object(this, attr_key)); + } else { + return ObjectPath(make_object(this)); + } +} + +ObjectPath ObjectPathNode::Attr(String attr_key) { + if (attr_key.defined()) { + return ObjectPath(make_object(this, attr_key)); + } else { + return ObjectPath(make_object(this)); + } +} + +TVM_REGISTER_GLOBAL("node.ObjectPathAttr") + .set_body_typed([](const ObjectPath& path, Optional attr_key) { + return path->Attr(attr_key.defined() ? attr_key.value() : String(nullptr)); + }); + +// --- ArrayIndex --- + +ObjectPath ObjectPathNode::ArrayIndex(size_t index) { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex") + .set_body_typed([](const ObjectPath& path, size_t index) { return path->ArrayIndex(index); }); + +// --- MissingArrayElement --- + +ObjectPath ObjectPathNode::MissingArrayElement(size_t index) { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") + .set_body_typed([](const ObjectPath& path, size_t index) { + return path->MissingArrayElement(index); + }); + +// --- MapValue --- + +ObjectPath ObjectPathNode::MapValue(ObjectRef key) { + return ObjectPath(make_object(this, std::move(key))); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMapValue") + .set_body_typed([](const ObjectPath& path, const ObjectRef& key) { + return path->MapValue(key); + }); + +// --- MissingMapEntry --- + +ObjectPath ObjectPathNode::MissingMapEntry() { + return ObjectPath(make_object(this)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry").set_body_typed([](const ObjectPath& path) { + return path->MissingMapEntry(); +}); + +// --- PathsEqual ---- + +bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { + if (!other.defined() || Length() != other->Length()) { + return false; + } + + const ObjectPathNode* lhs = this; + const ObjectPathNode* rhs = static_cast(other.get()); + + while (lhs != nullptr && rhs != nullptr) { + if (lhs->type_index() != rhs->type_index()) { + return false; + } + if (!lhs->LastNodeEqual(rhs)) { + return false; + } + lhs = lhs->ParentNode(); + rhs = rhs->ParentNode(); + } + + return lhs == nullptr && rhs == nullptr; +} + +TVM_REGISTER_GLOBAL("node.ObjectPathEqual") + .set_body_typed([](const ObjectPath& lhs, const ObjectPath& rhs) { + return lhs->PathsEqual(rhs); + }); + +// --- Repr --- + +std::string GetObjectPathRepr(const ObjectPathNode* node) { + std::string ret; + while (node != nullptr) { + std::string node_str = node->LastNodeString(); + ret.append(node_str.rbegin(), node_str.rend()); + node = static_cast(node->GetParent().get()); + } + std::reverse(ret.begin(), ret.end()); + return ret; +} + +static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) { + p->stream << GetObjectPathRepr(static_cast(node.get())); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// --- Private/protected methods --- + +const ObjectPathNode* ObjectPathNode::ParentNode() const { + return static_cast(parent_.get()); +} + +// ============== ObjectPath ============== + +/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object()); } + +TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed([]() { return ObjectPath::Root(); }); + +// ============== Individual path classes ============== + +// ----- Root ----- + +RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {} + +bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string RootPathNode::LastNodeString() const { return ""; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- AttributeAccess ----- + +AttributeAccessPathNode::AttributeAccessPathNode(ObjectPathNode* parent, String attr_key) + : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} + +bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherAttrAccess = static_cast(other); + return attr_key == otherAttrAccess->attr_key; +} + +std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- UnknownAttributeAccess ----- + +UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(ObjectPathNode* parent) + : ObjectPathNode(parent) {} + +bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + // Consider any two unknown attribute accesses unequal + return false; +} + +std::string UnknownAttributeAccessPathNode::LastNodeString() const { + return "."; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- ArrayIndexPath ----- + +ArrayIndexPathNode::ArrayIndexPathNode(ObjectPathNode* parent, size_t index) + : ObjectPathNode(parent), index(index) {} + +bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherArrayIndex = static_cast(other); + return index == otherArrayIndex->index; +} + +std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingArrayElement ----- + +MissingArrayElementPathNode::MissingArrayElementPathNode(ObjectPathNode* parent, size_t index) + : ObjectPathNode(parent), index(index) {} + +bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMissingElement = static_cast(other); + return index == otherMissingElement->index; +} + +std::string MissingArrayElementPathNode::LastNodeString() const { + return "[]"; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- MapValue ----- + +MapValuePathNode::MapValuePathNode(ObjectPathNode* parent, ObjectRef key) + : ObjectPathNode(parent), key(std::move(key)) {} + +bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMapValue = static_cast(other); + return ObjectEqual()(key, otherMapValue->key); +} + +std::string MapValuePathNode::LastNodeString() const { + std::ostringstream s; + s << "[" << key << "]"; + return s.str(); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingMapEntry ----- + +MissingMapEntryPathNode::MissingMapEntryPathNode(ObjectPathNode* parent) : ObjectPathNode(parent) {} + +bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string MissingMapEntryPathNode::LastNodeString() const { return "[]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +} // namespace tvm diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py new file mode 100644 index 000000000000..e48542fef702 --- /dev/null +++ b/tests/python/unittest/test_object_path.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm.runtime import object_path +from tvm.runtime.object_path import ObjectPath + + +def test_root_path(): + root = ObjectPath.root() + assert isinstance(root, object_path.RootPath) + assert str(root) == "" + assert len(root) == 1 + assert root == ObjectPath.root() + assert root.parent is None + + +def test_path_attr(): + path = ObjectPath.root().attr("foo") + assert isinstance(path, object_path.AttributeAccessPath) + assert str(path) == ".foo" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_attr_unknown(): + path = ObjectPath.root().attr(None) + assert isinstance(path, object_path.UnknownAttributeAccessPath) + assert str(path) == "." + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_array_index(): + path = ObjectPath.root().array_index(2) + assert isinstance(path, object_path.ArrayIndexPath) + assert str(path) == "[2]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_missing_array_element(): + path = ObjectPath.root().missing_array_element(2) + assert isinstance(path, object_path.MissingArrayElementPath) + assert str(path) == "[]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_map_value(): + path = ObjectPath.root().map_value("foo") + assert isinstance(path, object_path.MapValuePath) + assert str(path) == '["foo"]' + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_missing_map_entry(): + path = ObjectPath.root().missing_map_entry() + assert isinstance(path, object_path.MissingMapEntryPath) + assert str(path) == "[]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +@pytest.mark.parametrize( + "a, b, expected", + [ + (ObjectPath.root(), ObjectPath.root(), True), + (ObjectPath.root(), None, False), + (ObjectPath.root(), ObjectPath.root().attr("foo"), True), + (ObjectPath.root().attr("foo"), ObjectPath.root(), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo"), True), + (ObjectPath.root().attr("bar"), ObjectPath.root().attr("foo"), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo").array_index(2), True), + (ObjectPath.root().attr("foo").array_index(2), ObjectPath.root().attr("foo"), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("bar").array_index(2), False), + ], +) +def test_path_is_prefix_of(a, b, expected): + assert a.is_prefix_of(b) == expected + + +paths_for_equality_test = [ + ObjectPath.root(), + ObjectPath.root().attr("foo"), + ObjectPath.root().attr("bar"), + ObjectPath.root().array_index(3), + ObjectPath.root().array_index(4), + ObjectPath.root().missing_array_element(3), + ObjectPath.root().missing_array_element(4), + ObjectPath.root().map_value("foo"), + ObjectPath.root().map_value("bar"), + ObjectPath.root().missing_map_entry(), + ObjectPath.root().attr("foo").missing_map_entry(), +] + + +def make_test_params_for_eq_test(): + return [ + pytest.param(idx, path, id="path{}".format(idx)) + for idx, path in enumerate(paths_for_equality_test) + ] + + +@pytest.mark.parametrize("a_idx, a_path", make_test_params_for_eq_test()) +@pytest.mark.parametrize("b_idx, b_path", make_test_params_for_eq_test()) +def test_path_equal(a_idx, a_path, b_idx, b_path): + expected = a_idx == b_idx + result = a_path == b_path + assert result == expected + + +def test_path_get_prefix(): + p1 = ObjectPath.root() + p2 = p1.attr("foo") + p3 = p2.array_index(5) + + assert p3.parent == p2 + assert p2.parent == p1 + assert p1.parent is None + + assert p2[:0] is None + assert p2[:1] == p1 + assert p2[:-2] is None + assert p2[:-1] == p1 + + assert p3[:0] is None + assert p3[:1] == p1 + assert p3[:2] == p2 + assert p3[:3] == p3 + assert p3[:-3] is None + assert p3[:-2] == p1 + assert p3[:-1] == p2 + + with pytest.raises(tvm._ffi.base.TVMError) as e: + p3[:-4] + assert "Prefix length can't be negative" in str(e.value) + + with pytest.raises(tvm._ffi.base.TVMError) as e: + p3[:4] + assert "Attempted to get a prefix longer than the path itself" in str(e.value) From 34487ed98e51d050311b2545f6c703d490f12262 Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Fri, 1 Jul 2022 11:51:30 -0700 Subject: [PATCH 2/5] Address pylint errors --- python/tvm/runtime/object_path.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index 938e14ab4c9e..3c3798cf03fc 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -15,9 +15,14 @@ # specific language governing permissions and limitations # under the License. +""" +ObjectPath class that represents a path from a root object to one of its descendants +via attribute access, array indexing etc. +""" + import tvm._ffi -from . import _ffi_node_api from tvm.runtime import Object +from . import _ffi_node_api __all__ = ( @@ -34,6 +39,10 @@ @tvm._ffi.register_object("ObjectPath") class ObjectPath(Object): + """ + Path to an object from some root object. + """ + def __init__(self) -> None: super().__init__() raise ValueError( From 25d22ea20636be8731de8547f5ba9d0d35947179 Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Mon, 11 Jul 2022 11:44:17 -0700 Subject: [PATCH 3/5] Address Junru's comments --- include/tvm/node/object_path.h | 43 +++++++------- src/node/object_path.cc | 70 +++++++++++------------ tests/python/unittest/test_object_path.py | 11 +--- 3 files changed, 59 insertions(+), 65 deletions(-) diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h index 43f272c0f74b..5175c5b0c40d 100644 --- a/include/tvm/node/object_path.h +++ b/include/tvm/node/object_path.h @@ -26,6 +26,7 @@ #ifndef TVM_NODE_OBJECT_PATH_H_ #define TVM_NODE_OBJECT_PATH_H_ +#include #include #include @@ -54,20 +55,20 @@ class ObjectPath; class ObjectPathNode : public Object { public: /*! \brief Get the parent path */ - ObjectPath GetParent() const; + Optional GetParent() const; /*! * \brief Get the length of the path. * * For example, the path returned by `ObjectPath::Root()` has length 1. */ - size_t Length() const; + int32_t Length() const; /*! * \brief Get a path prefix of the given length. * * Provided `length` must not exceed the `Length()` of this path. */ - ObjectPath GetPrefix(size_t length) const; + ObjectPath GetPrefix(int32_t length) const; /*! * \brief Check if this path is a prefix of another path. @@ -80,28 +81,28 @@ class ObjectPathNode : public Object { bool PathsEqual(const ObjectPath& other) const; /*! \brief Extend this path with access to an object attribute. */ - ObjectPath Attr(const char* attr_key); + ObjectPath Attr(const char* attr_key) const; /*! \brief Extend this path with access to an object attribute. */ - ObjectPath Attr(String attr_key); + ObjectPath Attr(Optional attr_key) const; /*! \brief Extend this path with access to an array element. */ - ObjectPath ArrayIndex(size_t index); + ObjectPath ArrayIndex(int32_t index) const; /*! \brief Extend this path with access to a missing array element. */ - ObjectPath MissingArrayElement(size_t index); + ObjectPath MissingArrayElement(int32_t index) const; /*! \brief Extend this path with access to a map value. */ - ObjectPath MapValue(ObjectRef key); + ObjectPath MapValue(ObjectRef key) const; /*! \brief Extend this path with access to a missing map entry. */ - ObjectPath MissingMapEntry(); + ObjectPath MissingMapEntry() const; static constexpr const char* _type_key = "ObjectPath"; TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); protected: - explicit ObjectPathNode(ObjectPathNode* parent); + explicit ObjectPathNode(const ObjectPathNode* parent); friend class ObjectPath; friend std::string GetObjectPathRepr(const ObjectPathNode* node); @@ -114,8 +115,8 @@ class ObjectPathNode : public Object { virtual std::string LastNodeString() const = 0; private: - ObjectRef parent_; - size_t length_; + Optional parent_; + int32_t length_; }; class ObjectPath : public ObjectRef { @@ -123,7 +124,7 @@ class ObjectPath : public ObjectRef { /*! \brief Create a path that represents the root object itself. */ static ObjectPath Root(); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); }; //------------------------------------------------------------------------- @@ -156,7 +157,7 @@ class AttributeAccessPathNode final : public ObjectPathNode { /*! \brief Name of the attribute being accessed. Must be a static string. */ String attr_key; - explicit AttributeAccessPathNode(ObjectPathNode* parent, String attr_key); + explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key); static constexpr const char* _type_key = "AttributeAccessPath"; TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); @@ -175,7 +176,7 @@ class AttributeAccessPath : public ObjectPath { class UnknownAttributeAccessPathNode final : public ObjectPathNode { public: - explicit UnknownAttributeAccessPathNode(ObjectPathNode* parent); + explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent); static constexpr const char* _type_key = "UnknownAttributeAccessPath"; TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); @@ -196,9 +197,9 @@ class UnknownAttributeAccessPath : public ObjectPath { class ArrayIndexPathNode : public ObjectPathNode { public: /*! \brief Index of the array element that is being accessed. */ - size_t index; + int32_t index; - explicit ArrayIndexPathNode(ObjectPathNode* parent, size_t index); + explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index); static constexpr const char* _type_key = "ArrayIndexPath"; TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); @@ -218,9 +219,9 @@ class ArrayIndexPath : public ObjectPath { class MissingArrayElementPathNode : public ObjectPathNode { public: /*! \brief Index of the array element that is missing. */ - size_t index; + int32_t index; - explicit MissingArrayElementPathNode(ObjectPathNode* parent, size_t index); + explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index); static constexpr const char* _type_key = "MissingArrayElementPath"; TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); @@ -242,7 +243,7 @@ class MapValuePathNode : public ObjectPathNode { /*! \brief Key of the map entry that is being accessed */ ObjectRef key; - explicit MapValuePathNode(ObjectPathNode* parent, ObjectRef key); + explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key); static constexpr const char* _type_key = "MapValuePath"; TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); @@ -261,7 +262,7 @@ class MapValuePath : public ObjectPath { class MissingMapEntryPathNode : public ObjectPathNode { public: - explicit MissingMapEntryPathNode(ObjectPathNode* parent); + explicit MissingMapEntryPathNode(const ObjectPathNode* parent); static constexpr const char* _type_key = "MissingMapEntryPath"; TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); diff --git a/src/node/object_path.cc b/src/node/object_path.cc index 4596f139c116..6cdaec48759c 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -31,12 +31,18 @@ namespace tvm { // ============== ObjectPathNode ============== -ObjectPathNode::ObjectPathNode(ObjectPathNode* parent) +ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent) : parent_(GetRef(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} // --- GetParent --- -ObjectPath ObjectPathNode::GetParent() const { return Downcast(parent_); } +Optional ObjectPathNode::GetParent() const { + if (parent_ == nullptr) { + return NullOpt; + } else { + return Downcast(parent_); + } +} TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_typed([](const ObjectPath& path) { return path->GetParent(); @@ -44,22 +50,21 @@ TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_typed([](const ObjectPa // --- Length --- -size_t ObjectPathNode::Length() const { return length_; } +int32_t ObjectPathNode::Length() const { return length_; } TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_typed([](const ObjectPath& path) { - return static_cast(path->Length()); + return path->Length(); }); // --- GetPrefix --- -ObjectPath ObjectPathNode::GetPrefix(size_t length) const { - if (length > Length()) { - throw std::out_of_range("Attempted to get a prefix longer than the path itself"); - } +ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { + CHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1"; + CHECK_LE(length, Length()) << "IndexError: Attempted to get a prefix longer than the path itself"; const ObjectPathNode* node = this; - size_t suffix_len = Length() - length; - for (size_t i = 0; i < suffix_len; ++i) { + int32_t suffix_len = Length() - length; + for (int32_t i = 0; i < suffix_len; ++i) { node = node->ParentNode(); } @@ -67,21 +72,12 @@ ObjectPath ObjectPathNode::GetPrefix(size_t length) const { } TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix") - .set_body_typed([](const ObjectPath& path, int64_t length) { - if (length < 0) { - throw std::out_of_range("Prefix length can't be negative"); - } - return path->GetPrefix(static_cast(length)); - }); + .set_body_typed([](const ObjectPath& path, int64_t length) { return path->GetPrefix(length); }); // --- IsPrefixOf --- bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { - if (!other.defined()) { - return false; - } - - size_t this_len = Length(); + int32_t this_len = Length(); if (this_len > other->Length()) { return false; } @@ -93,7 +89,7 @@ TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf") // --- Attr --- -ObjectPath ObjectPathNode::Attr(const char* attr_key) { +ObjectPath ObjectPathNode::Attr(const char* attr_key) const { if (attr_key != nullptr) { return ObjectPath(make_object(this, attr_key)); } else { @@ -101,9 +97,9 @@ ObjectPath ObjectPathNode::Attr(const char* attr_key) { } } -ObjectPath ObjectPathNode::Attr(String attr_key) { +ObjectPath ObjectPathNode::Attr(Optional attr_key) const { if (attr_key.defined()) { - return ObjectPath(make_object(this, attr_key)); + return ObjectPath(make_object(this, attr_key.value())); } else { return ObjectPath(make_object(this)); } @@ -116,27 +112,27 @@ TVM_REGISTER_GLOBAL("node.ObjectPathAttr") // --- ArrayIndex --- -ObjectPath ObjectPathNode::ArrayIndex(size_t index) { +ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { return ObjectPath(make_object(this, index)); } TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex") - .set_body_typed([](const ObjectPath& path, size_t index) { return path->ArrayIndex(index); }); + .set_body_typed([](const ObjectPath& path, int32_t index) { return path->ArrayIndex(index); }); // --- MissingArrayElement --- -ObjectPath ObjectPathNode::MissingArrayElement(size_t index) { +ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { return ObjectPath(make_object(this, index)); } TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") - .set_body_typed([](const ObjectPath& path, size_t index) { + .set_body_typed([](const ObjectPath& path, int32_t index) { return path->MissingArrayElement(index); }); // --- MapValue --- -ObjectPath ObjectPathNode::MapValue(ObjectRef key) { +ObjectPath ObjectPathNode::MapValue(ObjectRef key) const { return ObjectPath(make_object(this, std::move(key))); } @@ -147,7 +143,7 @@ TVM_REGISTER_GLOBAL("node.ObjectPathMapValue") // --- MissingMapEntry --- -ObjectPath ObjectPathNode::MissingMapEntry() { +ObjectPath ObjectPathNode::MissingMapEntry() const { return ObjectPath(make_object(this)); } @@ -229,7 +225,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjec // ----- AttributeAccess ----- -AttributeAccessPathNode::AttributeAccessPathNode(ObjectPathNode* parent, String attr_key) +AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key) : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { @@ -244,7 +240,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ----- UnknownAttributeAccess ----- -UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(ObjectPathNode* parent) +UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent) : ObjectPathNode(parent) {} bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { @@ -261,7 +257,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ----- ArrayIndexPath ----- -ArrayIndexPathNode::ArrayIndexPathNode(ObjectPathNode* parent, size_t index) +ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index) : ObjectPathNode(parent), index(index) {} bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { @@ -275,7 +271,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(Prin // ----- MissingArrayElement ----- -MissingArrayElementPathNode::MissingArrayElementPathNode(ObjectPathNode* parent, size_t index) +MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent, + int32_t index) : ObjectPathNode(parent), index(index) {} bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { @@ -292,7 +289,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ----- MapValue ----- -MapValuePathNode::MapValuePathNode(ObjectPathNode* parent, ObjectRef key) +MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, ObjectRef key) : ObjectPathNode(parent), key(std::move(key)) {} bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { @@ -310,7 +307,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintO // ----- MissingMapEntry ----- -MissingMapEntryPathNode::MissingMapEntryPathNode(ObjectPathNode* parent) : ObjectPathNode(parent) {} +MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent) + : ObjectPathNode(parent) {} bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py index e48542fef702..ef526350ea1e 100644 --- a/tests/python/unittest/test_object_path.py +++ b/tests/python/unittest/test_object_path.py @@ -82,7 +82,6 @@ def test_path_missing_map_entry(): "a, b, expected", [ (ObjectPath.root(), ObjectPath.root(), True), - (ObjectPath.root(), None, False), (ObjectPath.root(), ObjectPath.root().attr("foo"), True), (ObjectPath.root().attr("foo"), ObjectPath.root(), False), (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo"), True), @@ -135,23 +134,19 @@ def test_path_get_prefix(): assert p2.parent == p1 assert p1.parent is None - assert p2[:0] is None assert p2[:1] == p1 - assert p2[:-2] is None assert p2[:-1] == p1 - assert p3[:0] is None assert p3[:1] == p1 assert p3[:2] == p2 assert p3[:3] == p3 - assert p3[:-3] is None assert p3[:-2] == p1 assert p3[:-1] == p2 - with pytest.raises(tvm._ffi.base.TVMError) as e: + with pytest.raises(IndexError) as e: p3[:-4] - assert "Prefix length can't be negative" in str(e.value) + assert "Prefix length must be at least 1" in str(e.value) - with pytest.raises(tvm._ffi.base.TVMError) as e: + with pytest.raises(IndexError) as e: p3[:4] assert "Attempted to get a prefix longer than the path itself" in str(e.value) From 601252569107fbed944049c1214147c7f03c45ba Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Tue, 12 Jul 2022 11:05:33 -0700 Subject: [PATCH 4/5] Make python API consistent with C++ --- python/tvm/runtime/object_path.py | 11 ++--------- tests/python/unittest/test_object_path.py | 15 ++++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index 3c3798cf03fc..3eabce1f8694 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -67,15 +67,8 @@ def parent(self) -> "ObjectPath": def __len__(self) -> int: return _ffi_node_api.ObjectPathLength(self) - def __getitem__(self, slc) -> "ObjectPath": - if not isinstance(slc, slice): - raise TypeError("ObjectPath can only be indexed with a slice") - if slc.start is not None or slc.step is not None: - raise ValueError( - "ObjectPath can only be indexed with a slice" " of the form [:prefix_len]" - ) - prefix_len = slc.stop if slc.stop >= 0 else len(self) + slc.stop - return _ffi_node_api.ObjectPathGetPrefix(self, prefix_len) + def get_prefix(self, length) -> "ObjectPath": + return _ffi_node_api.ObjectPathGetPrefix(self, length) def is_prefix_of(self, other) -> "ObjectPath": return _ffi_node_api.ObjectPathIsPrefixOf(self, other) diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py index ef526350ea1e..f849c129df59 100644 --- a/tests/python/unittest/test_object_path.py +++ b/tests/python/unittest/test_object_path.py @@ -134,19 +134,16 @@ def test_path_get_prefix(): assert p2.parent == p1 assert p1.parent is None - assert p2[:1] == p1 - assert p2[:-1] == p1 + assert p2.get_prefix(1) == p1 - assert p3[:1] == p1 - assert p3[:2] == p2 - assert p3[:3] == p3 - assert p3[:-2] == p1 - assert p3[:-1] == p2 + assert p3.get_prefix(1) == p1 + assert p3.get_prefix(2) == p2 + assert p3.get_prefix(3) == p3 with pytest.raises(IndexError) as e: - p3[:-4] + p3.get_prefix(0) assert "Prefix length must be at least 1" in str(e.value) with pytest.raises(IndexError) as e: - p3[:4] + p3.get_prefix(4) assert "Attempted to get a prefix longer than the path itself" in str(e.value) From 41015e9d7fad769345d7c004d61f23eb587ac994 Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Wed, 13 Jul 2022 11:12:27 -0700 Subject: [PATCH 5/5] Replace set_body_typed with set_body_method when possible --- src/node/object_path.cc | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/node/object_path.cc b/src/node/object_path.cc index 6cdaec48759c..9c49daa8c376 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -44,17 +44,14 @@ Optional ObjectPathNode::GetParent() const { } } -TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_typed([](const ObjectPath& path) { - return path->GetParent(); -}); +TVM_REGISTER_GLOBAL("node.ObjectPathGetParent") + .set_body_method(&ObjectPathNode::GetParent); // --- Length --- int32_t ObjectPathNode::Length() const { return length_; } -TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_typed([](const ObjectPath& path) { - return path->Length(); -}); +TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); // --- GetPrefix --- @@ -72,7 +69,7 @@ ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { } TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix") - .set_body_typed([](const ObjectPath& path, int64_t length) { return path->GetPrefix(length); }); + .set_body_method(&ObjectPathNode::GetPrefix); // --- IsPrefixOf --- @@ -85,7 +82,7 @@ bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { } TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf") - .set_body_typed([](const ObjectPath& a, const ObjectPath& b) { return a->IsPrefixOf(b); }); + .set_body_method(&ObjectPathNode::IsPrefixOf); // --- Attr --- @@ -106,8 +103,8 @@ ObjectPath ObjectPathNode::Attr(Optional attr_key) const { } TVM_REGISTER_GLOBAL("node.ObjectPathAttr") - .set_body_typed([](const ObjectPath& path, Optional attr_key) { - return path->Attr(attr_key.defined() ? attr_key.value() : String(nullptr)); + .set_body_typed([](const ObjectPath& object_path, Optional attr_key) { + return object_path->Attr(attr_key); }); // --- ArrayIndex --- @@ -117,7 +114,7 @@ ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { } TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex") - .set_body_typed([](const ObjectPath& path, int32_t index) { return path->ArrayIndex(index); }); + .set_body_method(&ObjectPathNode::ArrayIndex); // --- MissingArrayElement --- @@ -126,9 +123,7 @@ ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { } TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") - .set_body_typed([](const ObjectPath& path, int32_t index) { - return path->MissingArrayElement(index); - }); + .set_body_method(&ObjectPathNode::MissingArrayElement); // --- MapValue --- @@ -137,9 +132,7 @@ ObjectPath ObjectPathNode::MapValue(ObjectRef key) const { } TVM_REGISTER_GLOBAL("node.ObjectPathMapValue") - .set_body_typed([](const ObjectPath& path, const ObjectRef& key) { - return path->MapValue(key); - }); + .set_body_method(&ObjectPathNode::MapValue); // --- MissingMapEntry --- @@ -147,9 +140,8 @@ ObjectPath ObjectPathNode::MissingMapEntry() const { return ObjectPath(make_object(this)); } -TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry").set_body_typed([](const ObjectPath& path) { - return path->MissingMapEntry(); -}); +TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") + .set_body_method(&ObjectPathNode::MissingMapEntry); // --- PathsEqual ---- @@ -176,9 +168,7 @@ bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { } TVM_REGISTER_GLOBAL("node.ObjectPathEqual") - .set_body_typed([](const ObjectPath& lhs, const ObjectPath& rhs) { - return lhs->PathsEqual(rhs); - }); + .set_body_method(&ObjectPathNode::PathsEqual); // --- Repr --- @@ -209,7 +199,7 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const { /* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object()); } -TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed([]() { return ObjectPath::Root(); }); +TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); // ============== Individual path classes ==============