From 85e5d3a3d9d23dc8b7f9eb28128855031b536e5d Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 11 Aug 2022 19:20:06 -0400 Subject: [PATCH] Add IRDocsifier Co-authored-by: Greg Bonik Fix segfault at the end of execution Add doc for ContextManager Fix lint Fix lint --- include/tvm/script/printer/ir_docsifier.h | 189 +++++++++++++++++ .../script/printer/traced_object_functor.h | 21 ++ include/tvm/support/with.h | 30 +++ python/tvm/script/printer/ir_docsifier.py | 198 ++++++++++++++++++ src/script/printer/ir_docsifier.cc | 76 +++++++ src/script/printer/traced_object_functor.cc | 10 + .../cpp/tvmscript_printer_irdocsifier_test.cc | 112 ++++++++++ ...ript_printer_traced_object_functor_test.cc | 14 ++ .../test_tvmscript_printer_irdocsifier.py | 111 ++++++++++ 9 files changed, 761 insertions(+) create mode 100644 include/tvm/script/printer/ir_docsifier.h create mode 100644 python/tvm/script/printer/ir_docsifier.py create mode 100644 src/script/printer/ir_docsifier.cc create mode 100644 tests/cpp/tvmscript_printer_irdocsifier_test.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_irdocsifier.py diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h new file mode 100644 index 000000000000..c882cf1a0f90 --- /dev/null +++ b/include/tvm/script/printer/ir_docsifier.h @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ +#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +using WithCtx = With; + +/*! + * \brief IRDocsifier is the top-level interface in the IR->Doc process. + * + * It provides methods to convert IR node object to Doc, operate on Frame + * objects and change dispatch tokens. + * + * Example usage: + * \code + * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + * .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + * + * TracedObject var = ...; + * IRDocsifier p; + * p->AsDoc(var); // returns an IdDoc("x") + * \endcode + * + */ +class IRDocsifierNode : public Object { + public: + /*! + * \brief The var table to use during the printing process. + * \sa VarTableNode + */ + VarTable vars; + /*! + * \brief The stack of frames. + * \sa FrameNode + */ + Array frames; + /*! + * \brief The stack of dispatch tokens. + * + * The dispatch token on the top decides which dispatch function to use + * when converting IR node object to Doc. + */ + Array dispatch_tokens; + /*! + * \brief This map connects IR dipatch token to the name of identifier. + */ + Map ir_prefix; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("vars", &vars); + v->Visit("frames", &frames); + v->Visit("dispatch_tokens", &dispatch_tokens); + v->Visit("ir_prefix", &ir_prefix); + } + + static constexpr const char* _type_key = "script.printer.IRDocsifier"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object); + + public: + /*! + * \brief Transform the input object into TDoc. + * \param obj The object to be transformed. + * + * \return The Doc object. + */ + template + TDoc AsDoc(const TracedObject& obj) const { + auto result = Downcast(AsDocImpl(obj)); + result->source_paths.push_back(obj.GetPath()); + return result; + } + + /*! + * \brief Helper method to transform object into ExprDoc. + * \param obj The object to be transformed. + * + * \return The ExprDoc object. + */ + ExprDoc AsExprDoc(const TracedObject& obj) { return AsDoc(obj); } + + /*! + * \brief Push a new dispatch token into the stack + * \details The top dispatch token decides which dispatch table to use + * when printing Object. This method returns a RAII guard which + * pops the token when going out of the scope. + * + * \param token The dispatch token to push. + * + * \return A RAII guard to pop dispatch token when going out of scope. + */ + WithCtx WithDispatchToken(const String& token) { + this->dispatch_tokens.push_back(token); + return WithCtx(nullptr, [this]() { this->dispatch_tokens.pop_back(); }); + } + + /*! + * \brief Push a new frame the stack + * \details Frame contains the contextual information that's needed during printing, + * for example, variables in the scope. This method returns a RAII guard which + * pops the frame and call the cleanup method of frame when going out of the scope. + * + * \param frame The frame to push. + * + * \return A RAII guard to pop frame and call the exit method of frame + * when going out of scope + */ + WithCtx WithFrame(const Frame& frame) { + frame->EnterWithScope(); + this->frames.push_back(frame); + return WithCtx(nullptr, [this, pushed_frame = frame]() { + Frame last_frame = this->frames.back(); + ICHECK_EQ(last_frame, pushed_frame); + this->frames.pop_back(); + last_frame->ExitWithScope(); + }); + } + + /*! + * \brief Get the top frame with type FrameType + * \tparam FrameType The type of frame to get. + */ + template + Optional GetFrame() const { + for (auto it = frames.rbegin(); it != frames.rend(); ++it) { + if (const auto* f = (*it).as()) { + return GetRef(f); + } + } + return NullOpt; + } + + private: + Doc AsDocImpl(const TracedObject& obj) const; +}; + +/*! + * \brief Reference type of IRDocsifierNode. + */ +class IRDocsifier : public ObjectRef { + public: + /*! + * \brief Create a IRDocsifier. + * \param ir_prefix The ir_prefix to use for this IRDocsifier. + */ + explicit IRDocsifier(Map ir_prefix); + + using FType = TracedObjectFunctor; + /*! + * \brief The registration table for IRDocsifier. + */ + TVM_DLL static FType& vtable(); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 05fbbf79f2ee..6caaf8a6e0d5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -90,6 +90,15 @@ const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_tab void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, runtime::PackedFunc f); +/*! + * \brief Remove function from dispatch table. + * \param dispatch_table The dispatch table. + * \param token The dispatch token. + * \param type_index The TVM object type index for the dispatch function to be removed. + */ +void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token, + uint32_t type_index); + constexpr const char* kDefaultDispatchToken = ""; /*! @@ -173,6 +182,18 @@ class TracedObjectFunctor { return set_dispatch(kDefaultDispatchToken, std::forward(f)); } + /*! + * \brief Remove dispatch function + * \param token The dispatch token. + * \param type_index The TVM object type index for the dispatch function to be removed. + * + * This is useful when dispatch function comes from other language's runtime, and + * those function should be removed before that language runtime shuts down. + */ + void remove_dispatch(String token, uint32_t type_index) { + RemoveDispatchFunction(&dispatch_table_, token, type_index); + } + private: DispatchTable dispatch_table_; }; diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 3651e05e744c..5415c40991be 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -27,6 +27,7 @@ #include +#include #include namespace tvm { @@ -80,5 +81,34 @@ class With { ContextType ctx_; }; +/*! + * \brief A context type that delegates EnterWithScope and ExitWithScope + * to user-provided functions. + */ +class ContextManager { + public: + /*! + * \brief Constructor of ContextManager. + * \param f_enter The function to call when entering scope. If it's nullptr, do nothing when + * entering. + * \param f_exit The function to call when exiting scope. If it's nullptr, do nothing + * when exiting. + */ + template + explicit ContextManager(FEnter f_enter, FExit f_exit) : f_enter_(f_enter), f_exit_(f_exit) {} + + private: + void EnterWithScope() { + if (f_enter_) f_enter_(); + } + void ExitWithScope() { + if (f_exit_) f_exit_(); + } + std::function f_enter_; + std::function f_exit_; + template + friend class With; +}; + } // namespace tvm #endif // TVM_SUPPORT_WITH_H_ diff --git a/python/tvm/script/printer/ir_docsifier.py b/python/tvm/script/printer/ir_docsifier.py new file mode 100644 index 000000000000..16f3ab62ecab --- /dev/null +++ b/python/tvm/script/printer/ir_docsifier.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +IRDocsifier is the top-level interface in the process of transforming +IR graph into Doc tree, during printing IR graph as TVMScript code. +""" + +import atexit +from contextlib import ExitStack, contextmanager +from typing import Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar + +from tvm._ffi import get_object_type_index, register_object +from tvm.runtime import Object, ObjectPath + +from . import _ffi_api +from .doc import Doc +from .frame import Frame +from .var_table import VarTable + +_REGISTERED_TYPES: Set[Tuple[str, int]] = set() # {(dispatch_token, type_index)} + + +def _cleanup_dispatch_function(): + for dispatch_token, type_index in _REGISTERED_TYPES: + _ffi_api.IRDocsifierRemoveDispatch(dispatch_token, type_index) # type: ignore # pylint: disable=no-member + + +_CLEANUP_REGISTERED = False + + +def _ensure_cleanup_function_registered(): + """ + Add a cleanup function to be called on interpreter termination, + to remove all dispatch functions registered on the Python side. + + Without cleaning up those dispatch functions, program will segfault + on termination. It's because dispatch functions are referenced from the + static memory of libtvm, thus they will be cleaned up at the very end, + making calls to Py_DecRef after Python interpreter terminates. + """ + global _CLEANUP_REGISTERED # pylint: disable=global-statement + + if not _CLEANUP_REGISTERED: + atexit.register(_cleanup_dispatch_function) + _CLEANUP_REGISTERED = True + + +@register_object("script.printer.IRDocsifier") +class IRDocsifier(Object): + """ + IRDocsifier is the top-level interface in the IR->Doc process. + + It provides methods to convert IR node object to Doc, operate on Frame + objects and change dispatch tokens. + """ + + ir_prefix: Mapping[str, str] + vars: VarTable + frames: Sequence[Frame] + dispatch_tokens: Sequence[str] + + def __init__(self, ir_prefix: Dict[str, str]): + """ + Create a new IRDocsifier. + + Parameters + ---------- + ir_prefix : Dict[str, str] + The ir prefix to use. Key is the IR dispatch token and + value is the name of identifier for this IR's namespace in TVMScript. + """ + self.__init_handle_by_constructor__(_ffi_api.IRDocsifier, ir_prefix) # type: ignore # pylint: disable=no-member + + _TObject = TypeVar("_TObject", bound=Object) + + @classmethod + def set_dispatch( + cls, + node_type: Type[_TObject], + dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc], + dispatch_token: str = "", + ) -> None: + """ + Set the dispatch function to transform a particular IR node type to Doc + + Parameters + ---------- + node_type : Type[_TObject] + The type of object to dispatch on. + dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc] + The dispatch function. It's called to transform IR node object to Doc. + dispatch_token : str + Function will only be called when this dispatch_token is the same as the one + on the top of IRDocsifier's dispatch_tokens stack. An empty dispatch token + means registering as default dispatch function, which will be called when + there is no dispatch function registered with the current dispatch token. + """ + type_index = get_object_type_index(node_type) + if type_index is None: + raise TypeError(f"{type(node_type)} is not a registered TVM object type.") + + _ensure_cleanup_function_registered() + _ffi_api.IRDocsifierSetDispatch( # type: ignore # pylint: disable=no-member + dispatch_token, type_index, dispatch_function + ) + _REGISTERED_TYPES.add((dispatch_token, type_index)) + + def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc: + """ + Transform the input object into Doc. + + Parameters + ---------- + obj : Object + The IR node object. + object_path : ObjectPath + The object path of this object. It's used for locating diagnostic message. + + Returns + ------- + doc : Doc + The doc for this object. + """ + return _ffi_api.IRDocsifierAsDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member + + def get_frame(self, frame_type: Type[Frame]) -> Optional[Frame]: + """ + Get the top frame with type `frame_type`. + + Parameters + ---------- + frame_type : Type[Frame] + The target frame type. + + Returns + ------- + frame : Optional[Frame] + The frame if found, otherwise None. + """ + for i in range(len(self.frames) - 1, -1, -1): + if isinstance(self.frames[i], frame_type): + return self.frames[i] + return None + + @contextmanager + def dispatch_token(self, token: str): + """ + Push a new dispatch token to the stack. + + Parameters + ---------- + token : str + The token to push. + + Returns + ------- + A context manager that pops this dispatch token when exits. + """ + with ExitStack() as stack: + _ffi_api.IRDocsifierPushDispatchToken(self, token) # type: ignore # pylint: disable=no-member + stack.callback(_ffi_api.IRDocsifierPopDispatchToken, self) # type: ignore # pylint: disable=no-member + yield + + _TFrame = TypeVar("_TFrame", bound=Frame) + + @contextmanager + def frame(self, frame: _TFrame) -> Generator[_TFrame, None, None]: + """ + Push a new frame to the stack. + + Parameters + ---------- + frame : Frame + The frame to push. + + Returns + ------- + A context manager that pops this frame when exits. + """ + with ExitStack() as stack: + stack.enter_context(frame) + _ffi_api.IRDocsifierPushFrame(self, frame) # type: ignore # pylint: disable=no-member + stack.callback(_ffi_api.IRDocsifierPopFrame, self) # type: ignore # pylint: disable=no-member + yield frame diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc new file mode 100644 index 000000000000..7d9ba2352d88 --- /dev/null +++ b/src/script/printer/ir_docsifier.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +Doc IRDocsifierNode::AsDocImpl(const TracedObject& obj) const { + return IRDocsifier::vtable()(dispatch_tokens.back(), obj, GetRef(this)); +} + +IRDocsifier::IRDocsifier(Map ir_prefix) { + auto n = make_object(); + n->ir_prefix = std::move(ir_prefix); + n->dispatch_tokens.push_back(kDefaultDispatchToken); + data_ = std::move(n); +} + +IRDocsifier::FType& IRDocsifier::vtable() { + static IRDocsifier::FType inst; + return inst; +} + +TVM_REGISTER_NODE_TYPE(IRDocsifierNode); +TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map ir_prefix) { + return IRDocsifier(ir_prefix); +}); +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierAsDoc") + .set_body_typed([](IRDocsifier p, ObjectRef obj, ObjectPath obj_path) { + return p->AsDoc(MakeTraced(obj, obj_path)); + }); + +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushDispatchToken") + .set_body_typed([](IRDocsifier p, String token) { p->dispatch_tokens.push_back(token); }); +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopDispatchToken").set_body_typed([](IRDocsifier p) { + p->dispatch_tokens.pop_back(); +}); + +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushFrame") + .set_body_typed([](IRDocsifier p, Frame frame) { p->frames.push_back(frame); }); +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopFrame").set_body_typed([](IRDocsifier p) { + p->frames.pop_back(); +}); + +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierSetDispatch") + .set_body_typed([](String token, uint64_t type_index, runtime::PackedFunc f) { + IRDocsifier::vtable().set_dispatch(token, type_index, std::move(f)); + }); +TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch") + .set_body_typed([](String token, uint64_t type_index) { + IRDocsifier::vtable().remove_dispatch(token, type_index); + }); +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc index a018099a1de0..43160c7f4be4 100644 --- a/src/script/printer/traced_object_functor.cc +++ b/src/script/printer/traced_object_functor.cc @@ -70,6 +70,16 @@ void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uin } slot = f; } + +void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token, + uint32_t type_index) { + std::vector* table = &(*dispatch_table)[token]; + if (table->size() <= type_index) { + return; + } + (*table)[type_index] = nullptr; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc new file mode 100644 index 000000000000..fcdb5ed04e41 --- /dev/null +++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::script::printer; + +class TestObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "test.script.printer.irdocsifier.TestObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(TestObjectNode, Object); +}; + +class TestObject : public ObjectRef { + public: + TestObject() : ObjectRef(runtime::make_object()) {} + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TestObject, ObjectRef, TestObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(TestObjectNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { return IdDoc("tir"); }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("relax", + [](TracedObject obj, IRDocsifier p) { return IdDoc("relax"); }); + +TEST(PrinterIRDocsifierTest, AsDoc) { + IRDocsifier p(Map{}); + ObjectPath path = ObjectPath::Root(); + TestObject obj; + + IdDoc doc = p->AsDoc(MakeTraced(obj, path)); + + ICHECK_EQ(doc->name, "x"); +} + +TEST(PrinterIRDocsifierTest, AsExprDoc) { + IRDocsifier p(Map{}); + ObjectPath path = ObjectPath::Root(); + TestObject obj; + + ExprDoc doc = p->AsExprDoc(MakeTraced(obj, path)); + + ICHECK_EQ(Downcast(doc)->name, "x"); +} + +TEST(PrinterIRDocsifierTest, WithDispatchToken) { + IRDocsifier p(Map{}); + TracedObject obj = MakeTraced(TestObject(), ObjectPath::Root()); + + ICHECK_EQ(p->AsDoc(obj)->name, "x"); + + { + auto ctx = p->WithDispatchToken("tir"); + ICHECK_EQ(p->AsDoc(obj)->name, "tir"); + + { + auto ctx = p->WithDispatchToken("relax"); + ICHECK_EQ(p->AsDoc(obj)->name, "relax"); + } + + ICHECK_EQ(p->AsDoc(obj)->name, "tir"); + } + + ICHECK_EQ(p->AsDoc(obj)->name, "x"); +} + +TEST(PrinterIRDocsifierTest, WithFrame) { + IRDocsifier p(Map{}); + TestObject obj; + + { + VarDefFrame frame; + auto ctx = p->WithFrame(frame); + ICHECK_EQ(p->frames.size(), 1); + + p->vars->Define(obj, "x", ObjectPath::Root(), frame); + ICHECK(p->vars->IsVarDefined(obj)); + } + ICHECK_EQ(p->frames.size(), 0); + ICHECK(!p->vars->IsVarDefined(obj)); +} diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 3fd52d44aa8c..374eb609b6cb 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -127,6 +127,20 @@ TEST(TracedObjectFunctorTest, ExtraArg) { ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3); } +TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); + + functor.remove_dispatch("tir", FooObjectNode::RuntimeTypeIndex()); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); +} + TEST(TracedObjectFunctorTest, CallWithUnregisteredType) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); diff --git a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py new file mode 100644 index 000000000000..357a710584c1 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.runtime import ObjectPath +from tvm.script.printer.doc import IdDoc +from tvm.script.printer.frame import MetadataFrame, VarDefFrame +from tvm.script.printer.ir_docsifier import IRDocsifier +from tvm.tir import Var + + +@pytest.fixture +def ir_docsifier(): + """ + Creates an IRDocsifier instance with a special dispatch token. + """ + _ir_docsifier = IRDocsifier({}) + with _ir_docsifier.dispatch_token(f"{__file__}"): + yield _ir_docsifier + + +def _get_id_doc_printer(id_name): + def printer(obj, object_path, ir_docsifier): # pylint: disable=unused-argument + return IdDoc(id_name) + + return printer + + +# Because the dispatch table is global, tests should only set dispatch function under +# unique dispatch token. +IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x"), f"{__file__}") + + +def test_set_dispatch(ir_docsifier): + IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x2"), f"{__file__}-2") + with ir_docsifier.dispatch_token(f"{__file__}-2"): + doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root()) + assert doc.name == "x2" + + doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root()) + assert doc.name == "x" + + +def test_as_doc(ir_docsifier): + object_path = ObjectPath.root() + doc = ir_docsifier.as_doc(Var("x", "int8"), ObjectPath.root()) + assert doc.name == "x" + assert list(doc.source_paths) == [object_path] + + +def test_with_dispatch_token(ir_docsifier): + initial_token_count = len(ir_docsifier.dispatch_tokens) + + with ir_docsifier.dispatch_token("tir"): + assert len(ir_docsifier.dispatch_tokens) == initial_token_count + 1 + + assert len(ir_docsifier.dispatch_tokens) == initial_token_count + + +def test_with_frame(ir_docsifier): + initial_frame_count = len(ir_docsifier.frames) + + frame = VarDefFrame() + is_callback_called = False + + def callback(): + nonlocal is_callback_called + is_callback_called = True + + frame.add_exit_callback(callback) + + with ir_docsifier.frame(frame): + assert len(ir_docsifier.frames) == initial_frame_count + 1 + assert not is_callback_called + + assert len(ir_docsifier.frames) == initial_frame_count + assert is_callback_called + + +def test_get_frame(ir_docsifier): + with ir_docsifier.frame(VarDefFrame()) as frame_a: + assert ir_docsifier.get_frame(MetadataFrame) is None + assert ir_docsifier.get_frame(VarDefFrame) == frame_a + + with ir_docsifier.frame(VarDefFrame()) as frame_b: + assert ir_docsifier.get_frame(MetadataFrame) is None + assert ir_docsifier.get_frame(VarDefFrame) == frame_b + + with ir_docsifier.frame(MetadataFrame()) as frame_c: + assert ir_docsifier.get_frame(MetadataFrame) == frame_c + assert ir_docsifier.get_frame(VarDefFrame) == frame_b + + assert ir_docsifier.get_frame(MetadataFrame) is None + assert ir_docsifier.get_frame(VarDefFrame) == frame_b + + assert ir_docsifier.get_frame(MetadataFrame) is None + assert ir_docsifier.get_frame(VarDefFrame) == frame_a