diff --git a/include/tvm/script/printer.h b/include/tvm/script/printer.h deleted file mode 100644 index b0fc54108c92..000000000000 --- a/include/tvm/script/printer.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_H_ - -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! - * \brief Print IR graph as TVMScript code - * - * \param root_node The root node to print. - * \param ir_name The dispatch token of the target IR, e.g., "tir", "relax". - * \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}. - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - * - * \return the TVMScript code as string. - */ -String Script( // - const ObjectRef& root_node, // - String ir_name, // - Map ir_prefix, // - int indent_spaces = 4, // - bool print_line_numbers = false, // - int num_context_lines = -1, // - Optional path_to_underline = NullOpt // -); - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_H_ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 1ee7fd6a7fd4..094d3fdf51df 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,12 +22,13 @@ #include #include #include -#include namespace tvm { namespace script { namespace printer { +class Doc; + /*! * \brief The base class of all Doc. * @@ -88,15 +89,6 @@ class ExprDocNode : public DocNode { */ ExprDoc Attr(String attr) const; - /*! - * \brief Create a doc representing attribute access on the current ExprDoc - * \param attr The attribute to access. - * - * The ObjectPath of attr will be pushed to the source_path of the returned - * doc. - */ - ExprDoc Attr(TracedObject attr) const; - /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. @@ -259,83 +251,33 @@ class LiteralDoc : public ExprDoc { * \brief Create a LiteralDoc to represent None/null/empty value. */ static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } - - /*! - * \brief Create a LiteralDoc to represent None/null/empty value. - * \param object_path The source path of the returned Doc. - */ - static LiteralDoc None(ObjectPath object_path) { - return LiteralDoc(ObjectRef(nullptr), object_path); - } - /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ - static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } - - /*! - * \brief Create a LiteralDoc to represent integer. - * \param v The integer value. - * - * The ObjectPath of v will be pushed to the source_path of the returned doc. - */ - static LiteralDoc Int(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } - - /*! - * \brief Create a LiteralDoc to represent integer. - * \param v The integer value. - * - * The ObjectPath of v will be pushed to the source_path of the returned doc. - */ - static LiteralDoc Int(const TracedBasicValue& v) { - return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath()); - } + static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } - - /*! - * \brief Create a LiteralDoc to represent boolean. - * \param v The boolean value. - * - * The ObjectPath of v will be pushed to the source_path of the returned doc. - */ - static LiteralDoc Boolean(const TracedBasicValue& v) { - return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath()); - } - /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } - - /*! - * \brief Create a LiteralDoc to represent float. - * \param v The float value. - * - * The ObjectPath of v will be pushed to the source_path of the returned doc. - */ - static LiteralDoc Float(const TracedObject& v) { - return LiteralDoc(v.Get(), v.GetPath()); - } - /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ static LiteralDoc Str(const String& v) { return LiteralDoc(v); } - /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. - * - * The ObjectPath of v will be pushed to the source_path of the returned doc. */ - static LiteralDoc Str(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + static LiteralDoc DataType(const DLDataType& v) { + return LiteralDoc::Str(runtime::DLDataType2String(v)); + } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h deleted file mode 100644 index 04a67a9b8209..000000000000 --- a/include/tvm/script/printer/doc_printer.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ - -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! - * \brief Convert Doc into Python script. - * - * This function unpacks the DocPrinterOptions into function arguments - * to be FFI friendly. - * - * \param doc Doc to be converted - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ -String DocToPythonScript(Doc doc, int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt); - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h deleted file mode 100644 index 407ad16007e9..000000000000 --- a/include/tvm/script/printer/frame.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_FRAME_H_ -#define TVM_SCRIPT_PRINTER_FRAME_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! - * Frame is the core data structure for semantic information - * when printing IR graph into TVMScript code. - */ -class FrameNode : public Object { - public: - void VisitAttrs(tvm::AttrVisitor* v) {} - - virtual ~FrameNode() = default; - - /*! - * \brief Add a callback function to be called when this frame exits. - * \param cb The callback function. It should have signature void(). - */ - template - void AddExitCallback(TCallback&& cb) { - callbacks_.emplace_back(std::forward(cb)); - } - - /*! - * \brief Method that's called when Frame enters the scope. - */ - virtual void EnterWithScope() {} - - /*! - * \brief Method that's called when Frame exits the scope. - */ - virtual void ExitWithScope() { - for (const std::function& callback : callbacks_) { - callback(); - } - callbacks_.clear(); - } - - static constexpr const char* _type_key = "script.printer.Frame"; - TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); - - private: - std::vector> callbacks_; -}; - -/*! - * \brief Reference type of FrameNode - */ -class Frame : public ObjectRef { - protected: - Frame() = default; - - public: - virtual ~Frame() = default; - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); -}; - -/*! - * \brief MetadataFrame contains information like contant parameter array. - */ -class MetadataFrameNode : public FrameNode { - public: - Array metadata; - - void VisitAttrs(tvm::AttrVisitor* v) { - FrameNode::VisitAttrs(v); - v->Visit("metadata", &metadata); - } - - static constexpr const char* _type_key = "script.printer.MetadataFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode); -}; - -/*! - * \brief Reference type of MetadataFrameNode - */ -class MetadataFrame : public Frame { - public: - MetadataFrame(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode); -}; - -/*! - * \brief VarDefFrame contains information about the free variables that needs to be defined - * at the beginning of the printed snippet. - */ -class VarDefFrameNode : public FrameNode { - public: - Array stmts; - - void VisitAttrs(tvm::AttrVisitor* v) { - FrameNode::VisitAttrs(v); - v->Visit("stmts", &stmts); - } - - static constexpr const char* _type_key = "script.printer.VarDefFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode); -}; - -/*! - * \brief Reference type of VarDefFrameNode - */ -class VarDefFrame : public Frame { - public: - VarDefFrame(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode); -}; - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_FRAME_H_ diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 8945bd6e7a94..e97ddc0234b6 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -19,45 +19,117 @@ #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ +#include #include -#include #include -#include -#include -#include -#include -#include +#include + +#include +#include +#include +#include namespace tvm { namespace script { namespace printer { -using WithCtx = With; +//////////////////////// Frame //////////////////////// + +class IRDocsifier; +class IRDocsifierNode; + +/*! + * Frame is the core data structure for semantic information + * when printing IR graph into TVMScript code. + */ +class FrameNode : public Object { + public: + /*! The docs generated in the frame */ + Array stmts; + /*! The corresponding IRDocsifier */ + IRDocsifierNode* d; + /*! The callbacks that are going to be invoked when the frame exits */ + std::vector> callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("stmts", &stmts); + // `d` is not visited + // `callbacks` is not visited + } + + static constexpr const char* _type_key = "script.printer.Frame"; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); + + public: + virtual ~FrameNode() = default; + + /*! + * \brief Add a callback function to be called when this frame exits. + * \param cb The callback function. It should have signature void(). + */ + template + void AddExitCallback(TCallback&& cb) { + callbacks.emplace_back(std::forward(cb)); + } + /*! + * \brief Add a dispatch token to the docsifier, and a callback that pops the token when this + * frame exits. + * \param d The docsifier. + * \param token The token to be added. + */ + void AddDispatchToken(const IRDocsifier& d, const String& token); + /*! + * \brief Method that's called when Frame enters the scope. + */ + virtual void EnterWithScope(); + /*! + * \brief Method that's called when Frame exits the scope. + */ + virtual void ExitWithScope(); +}; + +/*! + * \brief Reference type of FrameNode + */ +class Frame : public ObjectRef { + protected: + Frame() = default; + + public: + virtual ~Frame() = default; + + /*! \brief Method that's called when Frame enters the scope. */ + void EnterWithScope() { get()->EnterWithScope(); } + + /*! \brief Method that's called when Frame exits the scope. */ + void ExitWithScope() { get()->ExitWithScope(); } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); +}; + +//////////////////////// IRDocsifier //////////////////////// /*! * \brief IRDocsifier is the top-level interface in the IR->Doc process. * * It provides methods to convert IR node object to Doc, operate on Frame * objects and change dispatch tokens. - * - * Example usage: - * \code - * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - * .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); - * - * TracedObject var = ...; - * IRDocsifier p; - * p->AsDoc(var); // returns an IdDoc("x") - * \endcode - * */ class IRDocsifierNode : public Object { public: + /*! \brief A function that creates the doc for a variable */ + using DocCreator = std::function; + /*! \brief Information about a variable, including its optional name and its doc creator */ + struct VariableInfo { + /*! \brief The creator */ + DocCreator creator; + /*! \brief The name of the variable */ + Optional name; + }; /*! - * \brief The var table to use during the printing process. - * \sa VarTableNode + * \brief This map connects IR dispatch token to the name of identifier. */ - VarTable vars; + Map ir_prefix; /*! * \brief The stack of frames. * \sa FrameNode @@ -70,16 +142,23 @@ class IRDocsifierNode : public Object { * when converting IR node object to Doc. */ Array dispatch_tokens; - /*! - * \brief This map connects IR dipatch token to the name of identifier. - */ - Map ir_prefix; + /*! \brief The IRModule to be docsifier is handling */ + Optional mod; + /*! \brief Mapping from a var to its info */ + std::unordered_map obj2info; + /*! \brief The variable names used already */ + std::unordered_set defined_names; + /*! \brief Common prefixes of variable usages */ + std::unordered_map> common_prefix; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("vars", &vars); + v->Visit("ir_prefix", &ir_prefix); v->Visit("frames", &frames); v->Visit("dispatch_tokens", &dispatch_tokens); - v->Visit("ir_prefix", &ir_prefix); + v->Visit("mod", &mod); + // `obj2info` is not visited + // `defined_names` is not visited + // `common_prefix` is not visited } static constexpr const char* _type_key = "script.printer.IRDocsifier"; @@ -87,79 +166,68 @@ class IRDocsifierNode : public Object { public: /*! - * \brief Transform the input object into TDoc. - * \param obj The object to be transformed. + * \brief Define variable by name. + * \param obj The variable object. + * \param frame The frame that this variable is defined in. + * \param name_hint The hint for variable name. * - * \return The Doc object. + * \return The id doc for this variable. + * + * This function will rename the variable to avoid name conflict with other variables + * in the table. */ - template - TDoc AsDoc(const TracedObject& obj) const { - auto result = Downcast(AsDocImpl(obj)); - result->source_paths.push_back(obj.GetPath()); - return result; - } + IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint); /*! - * \brief Helper method to transform object into ExprDoc. - * \param obj The object to be transformed. + * \brief Define variable by doc factory. + * \param obj The variable object. + * \param frame The frame that this variable is defined in. + * \param doc_factory The function to return an ExprDoc object for this variable. * - * \return The ExprDoc object. + * This function is a special form of `Define`. Variable is mapped to ExprDoc rather + * than IdDoc. It's useful when a variable is implicitly defined without a name, like + * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc(""), "data")`. + * + * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to + * return a new Doc object every time it's called, as the returned doc will have + * different `source_path`. Currently there isn't a good way to deep copy a TVMObject + * so VarTable needs to call a factory function to get a freshly-constructed Doc object + * every time GetVarDoc is called. */ - ExprDoc AsExprDoc(const TracedObject& obj) { return AsDoc(obj); } + void Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory); /*! - * \brief Push a new dispatch token into the stack - * \details The top dispatch token decides which dispatch table to use - * when printing Object. This method returns a RAII guard which - * pops the token when going out of the scope. - * - * \param token The dispatch token to push. + * \brief Get the doc for variable. + * \param obj The variable object. * - * \return A RAII guard to pop dispatch token when going out of scope. + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. */ - WithCtx WithDispatchToken(const String& token) { - this->dispatch_tokens.push_back(token); - return WithCtx(nullptr, [this]() { this->dispatch_tokens.pop_back(); }); - } + Optional GetVarDoc(const ObjectRef& obj) const; /*! - * \brief Push a new frame the stack - * \details Frame contains the contextual information that's needed during printing, - * for example, variables in the scope. This method returns a RAII guard which - * pops the frame and call the cleanup method of frame when going out of the scope. - * - * \param frame The frame to push. + * \brief Check if a variable exists in the table. + * \param obj The variable object. * - * \return A RAII guard to pop frame and call the exit method of frame - * when going out of scope + * \return a boolean for whether variable exists. */ - WithCtx WithFrame(const Frame& frame) { - frame->EnterWithScope(); - this->frames.push_back(frame); - return WithCtx(nullptr, [this, pushed_frame = frame]() { - Frame last_frame = this->frames.back(); - ICHECK_EQ(last_frame, pushed_frame); - this->frames.pop_back(); - last_frame->ExitWithScope(); - }); - } - + bool IsVarDefined(const ObjectRef& obj) const; + /*! \brief Remove the variable defined */ + void RemoveVar(const ObjectRef& obj); /*! - * \brief Get the top frame with type FrameType - * \tparam FrameType The type of frame to get. + * \brief Set the common prefix information of variable usage. + * \param root The root of the AST. + * \param is_var A function that returns true if the given object is considered a variable. */ - template - Optional GetFrame() const { - for (auto it = frames.rbegin(); it != frames.rend(); ++it) { - if (const auto* f = (*it).as()) { - return GetRef(f); - } - } - return NullOpt; - } - - private: - Doc AsDocImpl(const TracedObject& obj) const; + void SetCommonPrefix(const ObjectRef& root, runtime::TypedPackedFunc is_var); + /*! + * \brief Transform the input object into TDoc. + * \param obj The object to be transformed. + * \param path The path to this object. + * + * \return The Doc object. + */ + template + inline TDoc AsDoc(const ObjectRef& obj, const ObjectPath& path) const; }; /*! @@ -167,61 +235,49 @@ class IRDocsifierNode : public Object { */ class IRDocsifier : public ObjectRef { public: + using FType = IRDocsifierFunctor; /*! * \brief Create a IRDocsifier. * \param ir_prefix The ir_prefix to use for this IRDocsifier. */ explicit IRDocsifier(Map ir_prefix); - - using FType = TracedObjectFunctor; - /*! - * \brief The registration table for IRDocsifier. - */ + /*! \brief The registration table for IRDocsifier. */ TVM_DLL static FType& vtable(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode); }; -/*! - * \brief A wrapper object to provide injection point for printer of each IR. - * - * For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer - * and be dispatched to the corresponding function first. This provides an injection point for - * each IR's printer implemention to add specialized logic, for example, pushing a special - * Frame to the IRDocsifier before doing any IR->Doc transformation. - * - * \code - * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - * .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { - * const ObjectRef& root_node = obj.Get()->root_node; - * // For example, relax printer can create a Frame specialized to Relax here - * RelaxGeneralFrame frame; - * auto ctx = p->WithFrame(frame); - * // More specialized logic for your IR. - * return p->AsDoc(MakeTraced(root_node)); - * }); - * \endcode - */ -class RootNodeContainerNode : public Object { - public: - /*! \brief The root node to print. */ - ObjectRef root_node; +//////////////////////// Implementation //////////////////////// - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); } +inline void FrameNode::EnterWithScope() { + if (d != nullptr) { + d->frames.push_back(GetRef(this)); + } +} - static constexpr const char* _type_key = "script.printer.RootNodeContainer"; - TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object); -}; +inline void FrameNode::ExitWithScope() { + for (const std::function& callback : callbacks) { + callback(); + } + callbacks.clear(); + if (d != nullptr) { + d->frames.pop_back(); + } +} -class RootNodeContainer : public ObjectRef { - public: - /*! - * \brief Constructor of RootNodeContainer. - * \param root_node The root node to print. - * */ - explicit RootNodeContainer(ObjectRef root_node); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode); -}; +template +inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const { + if (!obj.defined()) { + return Downcast(LiteralDoc::None()); + } + return Downcast( + IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this))); +} + +inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) { + d->dispatch_tokens.push_back(token); + this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); }); +} } // namespace printer } // namespace script diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h new file mode 100644 index 000000000000..d04d8c4d028a --- /dev/null +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ +#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief Dynamic dispatch functor based on ObjectPath. + * + * This functor dispatches based on the type of object and the input dispatch token. + */ +template +class IRDocsifierFunctor { + private: + using TSelf = IRDocsifierFunctor; + + template + using IsDispatchFunction = + typename std::is_convertible>; + + public: + /*! + * \brief Call the dispatch function. + * \param token The dispatch token. + * \param obj The object. + * \param args Other args. + * + * \return The return value of the dispatch function + * + * If the TObjectRef isn't registered with the token, it will try to find + * dispatch function for TObjectRef with the default dispatch token (empty string). + */ + template + R operator()(const String& token, TObjectRef obj, Args... args) const { + uint32_t type_index = obj.defined() ? obj->type_index() : 0; + const runtime::PackedFunc* pf = nullptr; + if ((pf = LookupDispatchTable(token, type_index)) != nullptr) { + return (*pf)(obj, args...); + } + if ((pf = LookupDispatchTable("", type_index)) != nullptr) { + return (*pf)(obj, args...); + } + ICHECK(false) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" + << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; + } + + /*! + * \brief Set the dispatch function + * \param token The dispatch token. + * \param type_index The TVM object type index for this dispatch function. + * \param f The dispatch function. + * + * This takes a type-erased packed function as input. It should be used + * through FFI boundary, for example, registering dispatch function from Python. + */ + TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { + std::vector* table = &dispatch_table_[token]; + if (table->size() <= type_index) { + table->resize(type_index + 1, nullptr); + } + runtime::PackedFunc& slot = (*table)[type_index]; + if (slot != nullptr) { + ICHECK(false) << "Dispatch for type is already registered: " + << runtime::Object::TypeIndex2Key(type_index); + } + slot = f; + return *this; + } + + /*! + * \brief Set the dispatch function + * \param token The dispatch token. + * \param f The dispatch function. + */ + template ::value>> + TSelf& set_dispatch(String token, TCallable f) { + return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(), + runtime::TypedPackedFunc(f)); + } + + /*! + * \brief Remove dispatch function + * \param token The dispatch token. + * \param type_index The TVM object type index for the dispatch function to be removed. + * + * This is useful when dispatch function comes from other language's runtime, and + * those function should be removed before that language runtime shuts down. + */ + void remove_dispatch(String token, uint32_t type_index) { + std::vector* table = &dispatch_table_[token]; + if (table->size() <= type_index) { + return; + } + (*table)[type_index] = nullptr; + } + + private: + /*! + * \brief Look up the dispatch table for the given token and type_index. + * \param token The dispatch token. + * \param type_index The TVM object type index. + * \return Returns the functor if the lookup succeeds, nullptr otherwise. + */ + const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const { + auto it = dispatch_table_.find(token); + if (it == dispatch_table_.end()) { + return nullptr; + } + const std::vector& tab = it->second; + if (type_index >= tab.size()) { + return nullptr; + } + const PackedFunc* f = &tab[type_index]; + if (f->defined()) { + return f; + } else { + return nullptr; + } + } + /* + * This type alias and the following free functions are created to reduce the binary bloat + * from template and also hide implementation details from this header + */ + using DispatchTable = std::unordered_map>; + /*! \brief The dispatch table. */ + DispatchTable dispatch_table_; +}; + +} // namespace printer +} // namespace script +} // namespace tvm +#endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h new file mode 100644 index 000000000000..31abd7d9ec89 --- /dev/null +++ b/include/tvm/script/printer/printer.h @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_PRINTER_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! \brief Default values in the TVMScript printer */ +struct Default { + /*! \brief Default data type of TIR buffer */ + DataType buffer_dtype = DataType::Float(32); + /*! \brief Default data type of integer literals */ + DataType int_dtype = DataType::Int(32); + /*! + * \brief Default data type of float literals. Right now we always print out the explicit type + * of floating point values, so setting it to Void means we do not print without the + * T.float32/T.float64 wrapper. + */ + DataType float_dtype = DataType::Void(); + /*! \brief Returns a singleton of the configuration */ + static Default* Instance(); + static DataType& BufferDType() { return Instance()->buffer_dtype; } + static DataType& IntDType() { return Instance()->int_dtype; } + static DataType& FloatDType() { return Instance()->float_dtype; } +}; + +/*! + * \brief The entry method for TVMScript printing + * \param obj The object to be printed + * \param ir_prefix The prefix of IR nodes + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + * \return The TVMScript text format + */ +String Script(ObjectRef obj, // + Map ir_prefix = {{"ir", "I"}, {"tir", "T"}}, // + int indent_spaces = 4, // + bool print_line_numbers = false, // + int num_context_lines = -1, // + Optional path_to_underline = NullOpt); + +/*! + * \brief Convert Doc into Python script. + * \param doc Doc to be converted + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + */ +String DocToPythonScript(Doc doc, // + int indent_spaces = 4, // + bool print_line_numbers = false, // + int num_context_lines = -1, // + Optional path_to_underline = NullOpt); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_PRINTER_H_ diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h deleted file mode 100644 index cb63c31cd4a5..000000000000 --- a/include/tvm/script/printer/traced_object.h +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/script/printer/traced_object.h - * Wrappers around TVM objects that also store an ObjectPath from some "root" object - * to the wrapper object. - */ - -#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ -#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ - -#include -#include -#include - -#include -#include - -namespace tvm { - -template -class TracedObject; -template -class TracedMap; -template -class TracedArray; -template -class TracedOptional; -template -class TracedBasicValue; - -namespace detail { - -template ::value> -struct TracedObjectWrapperSelector; - -template -struct TracedObjectWrapperSelector { - using Type = TracedBasicValue; -}; - -template -struct TracedObjectWrapperSelector { - using Type = TracedObject; -}; - -template -struct TracedObjectWrapperSelector, true> { - using Type = TracedMap; -}; - -template -struct TracedObjectWrapperSelector, true> { - using Type = TracedArray; -}; - -template -struct TracedObjectWrapperSelector, true> { - using Type = TracedOptional; -}; - -} // namespace detail - -/*! - * \brief Traced wrapper for regular (non-container) TVM objects. - */ -template -class TracedObject { - using ObjectType = typename RefT::ContainerType; - - public: - using ObjectRefType = RefT; - - // Don't use this direcly. For convenience, call MakeTraced() instead. - explicit TracedObject(const RefT& object_ref, ObjectPath path) - : ref_(object_ref), path_(std::move(path)) {} - - // Implicit conversion from a derived reference class - template - TracedObject(const TracedObject& derived) - : ref_(derived.Get()), path_(derived.GetPath()) {} - - /*! - * \brief Get a traced wrapper for an attribute of the wrapped object. - */ - template - typename detail::TracedObjectWrapperSelector::Type GetAttr(T BaseType::*member_ptr) const { - using WrapperType = typename detail::TracedObjectWrapperSelector::Type; - const ObjectType* node = static_cast(ref_.get()); - const T& attr = node->*member_ptr; - Optional attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr)); - return WrapperType(attr, path_->Attr(attr_key)); - } - - /*! - * \brief Access the wrapped object. - */ - const RefT& Get() const { return ref_; } - - /*! - * \brief Check if the reference to the wrapped object can be converted to `RefU`. - */ - template - bool IsInstance() const { - return ref_->template IsInstance(); - } - - /*! - * \brief Same as Get().defined(). - */ - bool defined() const { return ref_.defined(); } - - /*! - * \brief Convert the wrapped reference type to a subtype. - * - * Throws an exception if IsInstance() is false. - */ - template - TracedObject Downcast() const { - return TracedObject(tvm::runtime::Downcast(ref_), path_); - } - - /*! - * \brief Convert the wrapped reference type to a subtype. - * - * Returns an empty optional if IsInstance() is false. - */ - template - TracedOptional TryDowncast() const { - if (ref_->template IsInstance()) { - return Downcast(); - } else { - return TracedOptional(NullOpt, path_); - } - } - - /*! - * \brief Get the path of the wrapped object. - */ - const ObjectPath& GetPath() const { return path_; } - - private: - RefT ref_; - ObjectPath path_; -}; - -/*! - * \brief Iterator class for TracedMap - */ -template -class TracedMapIterator { - public: - using WrappedV = typename detail::TracedObjectWrapperSelector::Type; - using MapIter = typename Map::iterator; - - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = ptrdiff_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - explicit TracedMapIterator(MapIter iter, ObjectPath map_path) - : iter_(iter), map_path_(std::move(map_path)) {} - - bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; } - - bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; } - - pointer operator->() const = delete; - - reference operator*() const { - auto kv = *iter_; - return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first))); - } - - TracedMapIterator& operator++() { - ++iter_; - return *this; - } - - TracedMapIterator operator++(int) { - TracedMapIterator copy = *this; - ++(*this); - return copy; - } - - private: - MapIter iter_; - ObjectPath map_path_; -}; - -/*! - * \brief Traced wrapper for Map objects. - */ -template -class TracedMap { - public: - using WrappedV = typename detail::TracedObjectWrapperSelector::Type; - - using iterator = TracedMapIterator; - - // Don't use this direcly. For convenience, call MakeTraced() instead. - explicit TracedMap(Map map, ObjectPath path) - : map_(std::move(map)), path_(std::move(path)) {} - - /*! - * \brief Get a value by its key, wrapped in a traced wrapper. - */ - WrappedV at(const K& key) const { - auto it = map_.find(key); - ICHECK(it != map_.end()) << "No such key in Map"; - auto kv = *it; - return WrappedV(kv.second, path_->MapValue(kv.first)); - } - - /*! - * \brief Access the wrapped map object. - */ - const Map& Get() const { return map_; } - - /*! - * \brief Get the path of the wrapped object. - */ - const ObjectPath& GetPath() const { return path_; } - - /*! - * \brief Get an iterator to the first item of the map. - */ - iterator begin() const { return iterator(map_.begin(), path_); } - - /*! - * \brief Get an iterator to the end of the map. - */ - iterator end() const { return iterator(map_.end(), path_); } - - /*! - * \brief Returns true iff the wrapped map is empty. - */ - bool empty() const { return map_.empty(); } - - private: - Map map_; - ObjectPath path_; -}; - -/*! - * \brief Iterator class for TracedArray - */ -template -class TracedArrayIterator { - public: - using WrappedT = typename detail::TracedObjectWrapperSelector::Type; - - using difference_type = ptrdiff_t; - using value_type = WrappedT; - using pointer = WrappedT*; - using reference = WrappedT&; - using iterator_category = std::random_access_iterator_tag; - - explicit TracedArrayIterator(Array array, size_t index, ObjectPath array_path) - : array_(array), index_(index), array_path_(array_path) {} - - TracedArrayIterator& operator++() { - ++index_; - return *this; - } - TracedArrayIterator& operator--() { - --index_; - return *this; - } - TracedArrayIterator operator++(int) { - TracedArrayIterator copy = *this; - ++index_; - return copy; - } - TracedArrayIterator operator--(int) { - TracedArrayIterator copy = *this; - --index_; - return copy; - } - - TracedArrayIterator operator+(difference_type offset) const { - return TracedArrayIterator(array_, index_ + offset, array_path_); - } - - TracedArrayIterator operator-(difference_type offset) const { - return TracedArrayIterator(array_, index_ - offset, array_path_); - } - - difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; } - - bool operator==(TracedArrayIterator other) const { - return array_.get() == other.array_.get() && index_ == other.index_; - } - bool operator!=(TracedArrayIterator other) const { return !(*this == other); } - value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); } - - private: - Array array_; - size_t index_; - ObjectPath array_path_; -}; - -/*! - * \brief Traced wrapper for Array objects. - */ -template -class TracedArray { - public: - using WrappedT = typename detail::TracedObjectWrapperSelector::Type; - - using iterator = TracedArrayIterator; - - // Don't use this direcly. For convenience, call MakeTraced() instead. - explicit TracedArray(Array array, ObjectPath path) - : array_(std::move(array)), path_(std::move(path)) {} - - /*! - * \brief Access the wrapped array object. - */ - const Array& Get() const { return array_; } - - /*! - * \brief Get the path of the wrapped array object. - */ - const ObjectPath& GetPath() const { return path_; } - - /*! - * \brief Get an element by index, wrapped in a traced wrapper. - */ - WrappedT operator[](size_t index) const { - return WrappedT(array_[index], path_->ArrayIndex(index)); - } - - /*! - * \brief Get an iterator to the first array element. - * - * The iterator's dereference operator will automatically wrap each element in a traced wrapper. - */ - iterator begin() const { return iterator(array_, 0, path_); } - - /*! - * \brief Get an iterator to the end of the array. - * - * The iterator's dereference operator will automatically wrap each element in a traced wrapper. - */ - iterator end() const { return iterator(array_, array_.size(), path_); } - - /*! - * \brief Returns true iff the wrapped array is empty. - */ - bool empty() const { return array_.empty(); } - - /*! - * \brief Get the size of the wrapped array. - */ - size_t size() const { return array_.size(); } - - private: - Array array_; - ObjectPath path_; -}; - -/*! - * \brief Traced wrapper for Optional objects. - */ -template -class TracedOptional { - public: - using WrappedT = typename detail::TracedObjectWrapperSelector::Type; - - /*! - * \brief Implicit conversion from the corresponding non-optional traced wrapper. - */ - TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit) - : optional_(value.Get().defined() ? value.Get() : Optional(NullOpt)), - path_(value.GetPath()) {} - - // Don't use this direcly. For convenience, call MakeTraced() instead. - explicit TracedOptional(Optional optional, ObjectPath path) - : optional_(std::move(optional)), path_(std::move(path)) {} - - /*! - * \brief Access the wrapped optional object. - */ - const Optional& Get() const { return optional_; } - - /*! - * \brief Get the path of the wrapped optional object. - */ - const ObjectPath& GetPath() const { return path_; } - - /*! - * \brief Returns true iff the object is present. - */ - bool defined() const { return optional_.defined(); } - - /*! - * \brief Returns a non-optional traced wrapper, throws if defined() is false. - */ - WrappedT value() const { return WrappedT(optional_.value(), path_); } - - /*! - * \brief Same as defined(). - */ - explicit operator bool() const { return optional_.defined(); } - - private: - Optional optional_; - ObjectPath path_; -}; - -/*! - * \brief Traced wrapper for basic values (i.e. non-TVM objects) - */ -template -class TracedBasicValue { - public: - explicit TracedBasicValue(const T& value, ObjectPath path) - : value_(value), path_(std::move(path)) {} - - /*! - * \brief Access the wrapped value. - */ - const T& Get() const { return value_; } - - /*! - * \brief Get the path of the wrapped value. - */ - const ObjectPath& GetPath() const { return path_; } - - /*! - * \brief Transform the wrapped value without changing its path. - */ - template - typename detail::TracedObjectWrapperSelector::type>::Type - ApplyFunc(F&& f) const { - return MakeTraced(f(value_), path_); - } - - private: - T value_; - ObjectPath path_; -}; - -/*! - * \brief Wrap the given root object in an appropriate traced wrapper class. - */ -template -typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object) { - using WrappedT = typename detail::TracedObjectWrapperSelector::Type; - return WrappedT(object, ObjectPath::Root()); -} - -/*! - * \brief Wrap the given object with the given path in an appropriate traced wrapper class. - */ -template -typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object, - ObjectPath path) { - using WrappedT = typename detail::TracedObjectWrapperSelector::Type; - return WrappedT(object, std::move(path)); -} - -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h deleted file mode 100644 index 8f72d139a5a5..000000000000 --- a/include/tvm/script/printer/traced_object_functor.h +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ -#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/* - * This type alias and the following free functions are created to reduce the binary bloat - * from template and also hide implementation details from this header - */ -using DispatchTable = std::unordered_map>; - -/*! - * \brief Get function from dispatch table. - * \param dispatch_table The dispatch table. - * \param token The dispatch token. - * \param type_index The type index of the Object type to be dispatched. - * - * \return The dispatch function. - */ -const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, - const String& token, uint32_t type_index); - -/*! - * \brief Set function in dispatch table. - * \param dispatch_table The dispatch table. - * \param token The dispatch token. - * \param type_index The type index of the Object type to be dispatched. - * \param f The dispatch function. - */ -void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, - runtime::PackedFunc f); - -/*! - * \brief Remove function from dispatch table. - * \param dispatch_table The dispatch table. - * \param token The dispatch token. - * \param type_index The TVM object type index for the dispatch function to be removed. - */ -void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token, - uint32_t type_index); - -constexpr const char* kDefaultDispatchToken = ""; - -/*! - * \brief Dynamic dispatch functor based on TracedObject. - * - * This functor dispatches based on the type of object ref inside the input TracedObject, - * and the input dispatch token. - */ -template -class TracedObjectFunctor { - private: - using TSelf = TracedObjectFunctor; - - template - using IsDispatchFunction = - typename std::is_convertible, Args...)>>; - - public: - /*! - * \brief Call the dispatch function. - * \param token The dispatch token. - * \param traced_object The traced object. - * \param args Other args. - * - * \return The return value of the dispatch function - * - * If the TObjectRef isn't registered with the token, it will try to find - * dispatch function for TObjectRef with kDefaultDispatchToken. - */ - template - R operator()(const String& token, TracedObject traced_object, Args... args) const { - const runtime::PackedFunc& dispatch_function = - GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index()); - return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...); - } - - /*! - * \brief Set the dispatch function - * \param token The dispatch token. - * \param type_index The TVM object type index for this dispatch function. - * \param f The dispatch function. - * - * This takes a type-erased packed function as input. It should be used - * through FFI boundary, for example, registering dispatch function from Python. - */ - TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { - SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f)); - return *this; - } - - /*! - * \brief Set the dispatch function - * \param token The dispatch token. - * \param f The dispatch function. - * - * The diaptch function should have signature `R(TracedObject, Args...)`. - */ - template ::value>> - TSelf& set_dispatch(String token, TCallable f) { - return set_dispatch( - token, // - TObjectRef::ContainerType::RuntimeTypeIndex(), // - runtime::TypedPackedFunc( - [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R { - return f(MakeTraced(object, path), args...); - })); - } - /*! - * \brief Set the default dispatch function - * \param f The dispatch function. - * - * Default dispatch function will be used if there is no function registered - * with the requested dispatch token. - * - * Default dispatch function has an empty string as dispatch token. - */ - template ::value>> - TSelf& set_dispatch(TCallable&& f) { - return set_dispatch(kDefaultDispatchToken, std::forward(f)); - } - - /*! - * \brief Remove dispatch function - * \param token The dispatch token. - * \param type_index The TVM object type index for the dispatch function to be removed. - * - * This is useful when dispatch function comes from other language's runtime, and - * those function should be removed before that language runtime shuts down. - */ - void remove_dispatch(String token, uint32_t type_index) { - RemoveDispatchFunction(&dispatch_table_, token, type_index); - } - - private: - DispatchTable dispatch_table_; -}; - -} // namespace printer -} // namespace script -} // namespace tvm -#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h deleted file mode 100644 index 2cd9335213a3..000000000000 --- a/include/tvm/script/printer/var_table.h +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_ -#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_ - -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! - * \brief Variable Table manages mapping from variable object to ExprDoc during - * the process of printing TVMScript. - * - * The value type of this map is ExprDoc rather than IdDoc or String. It's - * because variables can be implicitly defined. For example in TIR buffer (tir::Buffer), - * `buf->data` is a variable, while its representation in TVMScript should be an - * expression `x.data`, where `x` is the variable for the buffer itself. - */ -class VarTableNode : public Object { - public: - void VisitAttrs(AttrVisitor*) {} - - /*! - * \brief Define variable by name. - * \param obj The variable object. - * \param name_hint The hint for variable name. - * \param object_path The object_path for the returned ExprDoc. - * \param frame The frame that this variable is defined in. - * - * \return The id doc for this variable. - * - * This function will rename the variable to avoid name conflict with other variables - * in the table. - */ - IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path, - const Frame& frame); - - /*! - * \brief Define variable by name. - * \param obj The variable object. - * \param name_hint The hint for variable name. - * \param frame The frame that this variable is defined in. - * - * \return The id doc for this variable. - * - * This is a shortcut version of `Define` which accepts a traced string. - */ - IdDoc Define(const ObjectRef& obj, const TracedObject& name_hint, const Frame& frame) { - return Define(obj, name_hint.Get(), name_hint.GetPath(), frame); - } - - using DocFactory = std::function; - - /*! - * \brief Define variable by doc factory. - * \param obj The variable object. - * \param doc_factory The function to return an ExprDoc object for this variable. - * \param frame The frame that this variable is defined in. - * - * This function is a special form of `Define`. Variable is mapped to ExprDoc rather - * than IdDoc. It's useful when a variable is implicitly defined without a name, like - * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc(""), "data")`. - * - * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to - * return a new Doc object every time it's called, as the returned doc will have - * different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject - * so VarTable needs to call a factory function to get a freshly-constructed Doc object - * every time GetVarDoc is called. - */ - void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame); - - /*! - * \brief Get the doc for variable. - * \param obj The variable object. - * \param object_path The object path for the variable. - * - * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. - */ - Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; - - /*! - * \brief Get the doc for variable. - * \param obj The traced variable object. - * - * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. - */ - template - Optional GetVarDoc(const TracedObject obj) const { - return GetVarDoc(obj.Get(), obj.GetPath()); - } - - /*! - * \brief Check if a variable exists in the table. - * \param obj The variable object. - * - * \return a boolean for whether variable exists. - */ - bool IsVarDefined(const ObjectRef& obj) const; - - static constexpr const char* _type_key = "script.printer.VarTable"; - TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object); - - private: - void RemoveVar(const ObjectRef& obj); - - struct VariableInfo { - DocFactory doc_factory; - Optional name; - }; - std::unordered_map obj2info; - std::unordered_set defined_names; -}; - -/*! - * \brief Reference type of VarTableNode. - */ -class VarTable : public ObjectRef { - public: - /*! - * \brief Create an empty VarTable. - */ - VarTable(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode); -}; - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_ diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 5959affafdb3..8333adc9e613 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -92,34 +92,5 @@ class With { ContextType ctx_; }; -/*! - * \brief A context type that delegates EnterWithScope and ExitWithScope - * to user-provided functions. - */ -class ContextManager { - public: - /*! - * \brief Constructor of ContextManager. - * \param f_enter The function to call when entering scope. If it's nullptr, do nothing when - * entering. - * \param f_exit The function to call when exiting scope. If it's nullptr, do nothing - * when exiting. - */ - template - explicit ContextManager(FEnter f_enter, FExit f_exit) : f_enter_(f_enter), f_exit_(f_exit) {} - - private: - void EnterWithScope() { - if (f_enter_) f_enter_(); - } - void ExitWithScope() { - if (f_exit_) f_exit_(); - } - std::function f_enter_; - std::function f_exit_; - template - friend class With; -}; - } // namespace tvm #endif // TVM_SUPPORT_WITH_H_ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9b48b0ccebd1..21bc7e7a5056 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -40,6 +40,9 @@ namespace tvm { +#define TVM_TIR_REGISTER_OP(OpName) \ + TVM_REGISTER_OP("tir." OpName).set_attr("TScriptPrinterName", OpName) + // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. // diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 2dc174f7d2a1..858d89c2d551 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -56,6 +56,11 @@ using FLowerIntrinsic = runtime::TypedPackedFunc; */ using FLegalize = runtime::TypedPackedFunc; +/*! + * \brief The operator's name in TVMScript printer + */ +using TScriptPrinterName = String; + /*! * \brief The effect type of the call. */ diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 21bdfa6f1691..82bb698f2773 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -15,4 +15,7 @@ # specific language governing permissions and limitations # under the License. """TVM Script APIs of TVM Python Package""" -from .parser import ir, ir_module, parse as from_source, tir +from .parser import ir, ir_module +from .parser import parse as from_source +from .parser import tir +from .printer import script diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 48b283447969..06a85fa34082 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -27,6 +27,7 @@ # isort: on import numpy as np # type: ignore + from tvm.ir import Range, Type from tvm.runtime import convert, ndarray from tvm.target import Target @@ -508,7 +509,9 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = "int32", ) -> Var: """The spatial block axis defining function. @@ -534,7 +537,9 @@ def spatial( @staticmethod def reduce( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = "int32", ) -> Var: """The reduced block axis defining function. @@ -560,7 +565,9 @@ def reduce( @staticmethod def scan( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = "int32", ) -> Var: """The scanning block axis defining function. @@ -586,7 +593,9 @@ def scan( @staticmethod def opaque( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + binding: PrimExpr, + dtype: str = "int32", ) -> Var: """The opaque block axis defining function. @@ -1534,34 +1543,41 @@ def target(target_config: Union[Dict, str]) -> Target: return Target(target_config) -def _op_wrapper(func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - if "dtype" in kwargs: - kwargs.pop("dtype") - return func(*args, **kwargs) +class meta_var: # pylint: disable=invalid-name + """A meta variable used in TVMScript metaprogramming. It means that the value of the variable + does not appear in the final TIR, but only stays in the parser. - return wrapped + Parameters + ---------- + value: Any + The meta variable. + """ + def __init__(self, value: Any) -> None: + self.value = value -def _dtype_forward(func): + def __iter__(self): + def f(): + for i in self.value: + yield meta_var(i) + + return f() + + +# pylint: disable=invalid-name + + +def _op_wrapper(func): @functools.wraps(func) def wrapped(*args, **kwargs): if "dtype" in kwargs: - args = (kwargs.pop("dtype"),) + args + kwargs.pop("dtype") return func(*args, **kwargs) return wrapped -# pylint: disable=invalid-name - -broadcast = Broadcast -ramp = Ramp - -buffer_var = ptr abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin -fabs = abs acos = _op_wrapper(_tir_op.acos) acosh = _op_wrapper(_tir_op.acosh) address_of = _op_wrapper(_tir_op.address_of) @@ -1607,7 +1623,6 @@ def wrapped(*args, **kwargs): q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) -reinterpret = _dtype_forward(_tir_op.reinterpret) round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin rsqrt = _op_wrapper(_tir_op.rsqrt) shift_left = _op_wrapper(_tir_op.shift_left) @@ -1631,11 +1646,6 @@ def wrapped(*args, **kwargs): call_cpacked = _op_wrapper(_tir_op.call_cpacked) call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered) -call_extern = _dtype_forward(_tir_op.call_extern) -call_intrin = _dtype_forward(_tir_op.call_intrin) -call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) -call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) -call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) tvm_struct_get = _tir_op.tvm_struct_get @@ -1645,48 +1655,51 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) -ptx_mma = _dtype_forward(_tir_op.ptx_mma) -ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) -ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) -ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) -mma_store = _dtype_forward(_tir_op.mma_store) -mma_fill = _dtype_forward(_tir_op.mma_fill) -vectorlow = _dtype_forward(_tir_op.vectorlow) -vectorhigh = _dtype_forward(_tir_op.vectorhigh) -vectorcombine = _dtype_forward(_tir_op.vectorcombine) assume = _op_wrapper(_tir_op.assume) undef = _op_wrapper(_tir_op.undef) -tvm_call_packed = call_packed -tvm_call_cpacked = call_cpacked -tvm_call_packed_lowered = call_packed_lowered -tvm_call_cpacked_lowered = call_cpacked_lowered TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) -class meta_var: - """A meta variable used in TVMScript metaprogramming. It means that the value of the variable - does not appear in the final TIR, but only stays in the parser. +def _dtype_forward(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + args = (kwargs.pop("dtype"),) + args + return func(*args, **kwargs) - Parameters - ---------- - value: Any - The meta variable. - """ + return wrapped - def __init__(self, value: Any) -> None: - self.value = value - def __iter__(self): - def f(): - for i in self.value: - yield meta_var(i) +reinterpret = _dtype_forward(_tir_op.reinterpret) +call_extern = _dtype_forward(_tir_op.call_extern) +call_intrin = _dtype_forward(_tir_op.call_intrin) +call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) +call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) +call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) +ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) +ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +mma_store = _dtype_forward(_tir_op.mma_store) +mma_fill = _dtype_forward(_tir_op.mma_fill) +vectorlow = _dtype_forward(_tir_op.vectorlow) +vectorhigh = _dtype_forward(_tir_op.vectorhigh) +vectorcombine = _dtype_forward(_tir_op.vectorcombine) - return f() + +broadcast = Broadcast +ramp = Ramp +buffer_var = ptr +fabs = abs +tvm_call_packed = call_packed +tvm_call_cpacked = call_cpacked +tvm_call_packed_lowered = call_packed_lowered +tvm_call_cpacked_lowered = call_cpacked_lowered # pylint: enable=invalid-name diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index d49614db0f21..25ea619a410c 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -16,12 +16,7 @@ # under the License. """ TVMScript Unified Printer - This package provides a set of APIs to print supported TVM IR into TVMScript in a roundtrippable way. - -https://github.com/apache/tvm-rfcs/blob/main/rfcs/0074-tvmscript-unified-printer.md """ - -from . import _ffi_api -from .entry import script +from .printer import script diff --git a/python/tvm/script/printer/entry.py b/python/tvm/script/printer/entry.py deleted file mode 100644 index c015702af09b..000000000000 --- a/python/tvm/script/printer/entry.py +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This file contains the entry point of TVMScript Unified Printer. -""" - -from typing import Dict, Optional - -from tvm.runtime import Object, ObjectPath - -from . import _ffi_api - - -def script( # pylint: disable=too-many-arguments - root_node: Object, - ir_name: str, - ir_prefix: Dict[str, str], - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: int = -1, - path_to_underline: Optional[ObjectPath] = None, -) -> str: - """ - Print IR graph as TVMScript code - - Parameters - ---------- - root_node : Object - The root node to print. - ir_name : str - The dispatch token of the target IR, e.g., "tir", "relax". - ir_prefix : Dict[str, str] - The symbol name for TVMScript IR namespaces. For example, - {"tir": "T"}. - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVMScript code of the root_node - """ - return _ffi_api.Script( # type: ignore # pylint: disable=no-member - root_node, - ir_name, - ir_prefix, - indent_spaces, - print_line_numbers, - num_context_lines, - path_to_underline, - ) diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py deleted file mode 100644 index c967382b8b5d..000000000000 --- a/python/tvm/script/printer/frame.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Frame is the core data structure for semantic information when printing -IR graph into TVMScript code. -""" - -from typing import Callable, Sequence - -from tvm._ffi import register_object -from tvm.runtime import Object -from tvm.script.printer.doc import StmtDoc - -from . import _ffi_api - - -class Frame(Object): - """ - Frame is the core data structure for semantic information - when printing IR graph into TVMScript code. - - Frame base class manages a list of callbacks to be executed - when frame goes out of scope. - """ - - def add_exit_callback(self, callback: Callable[[], None]) -> None: - """ - Adds a callback function to be executed when frame goes out of scope. - - Parameters - ---------- - callback : Callable[[], None] - The callback function. - """ - _ffi_api.FrameAddExitCallback(self, callback) # type: ignore # pylint: disable=no-member - - def __enter__(self): - _ffi_api.FrameEnterWithScope(self) # type: ignore # pylint: disable=no-member - return self - - def __exit__(self, *exception_info): - _ffi_api.FrameExitWithScope(self) # type: ignore # pylint: disable=no-member - - -@register_object("script.printer.MetadataFrame") -class MetadataFrame(Frame): - """ - MetadataFrame contains information like contant parameter array. - """ - - metadata: Sequence[Object] - - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type: ignore # pylint: disable=no-member - - -@register_object("script.printer.VarDefFrame") -class VarDefFrame(Frame): - """ - VarDefFrame contains information about the free variables that needs to - be defined at the beginning of the printed snippet. - """ - - stmts: Sequence[StmtDoc] - - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/script/printer/ir_docsifier.py b/python/tvm/script/printer/ir_docsifier.py deleted file mode 100644 index c5ba8a498b1e..000000000000 --- a/python/tvm/script/printer/ir_docsifier.py +++ /dev/null @@ -1,245 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -IRDocsifier is the top-level interface in the process of transforming -IR graph into Doc tree, during printing IR graph as TVMScript code. -""" - -import atexit -from contextlib import ExitStack, contextmanager -from typing import Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar - -from tvm._ffi import get_object_type_index, register_object -from tvm.runtime import Object, ObjectPath - -from . import _ffi_api -from .doc import Doc -from .frame import Frame -from .var_table import VarTable - -_REGISTERED_TYPES: Set[Tuple[str, int]] = set() # {(dispatch_token, type_index)} - - -def _cleanup_dispatch_function(): - for dispatch_token, type_index in _REGISTERED_TYPES: - _ffi_api.IRDocsifierRemoveDispatch(dispatch_token, type_index) # type: ignore # pylint: disable=no-member - - -_CLEANUP_REGISTERED = False - - -def _ensure_cleanup_function_registered(): - """ - Add a cleanup function to be called on interpreter termination, - to remove all dispatch functions registered on the Python side. - - Without cleaning up those dispatch functions, program will segfault - on termination. It's because dispatch functions are referenced from the - static memory of libtvm, thus they will be cleaned up at the very end, - making calls to Py_DecRef after Python interpreter terminates. - """ - global _CLEANUP_REGISTERED # pylint: disable=global-statement - - if not _CLEANUP_REGISTERED: - atexit.register(_cleanup_dispatch_function) - _CLEANUP_REGISTERED = True - - -@register_object("script.printer.RootNodeContainer") -class RootNodeContainer(Object): - """ - A wrapper object to provide injection point for printer of each IR. - - This class shouldn't be used directly. `IRDocsifier.set_root_dispatch` - should be used instead. - """ - - root_node: Object - - def __init__(self, root_node: Object): - self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member - - -@register_object("script.printer.IRDocsifier") -class IRDocsifier(Object): - """ - IRDocsifier is the top-level interface in the IR->Doc process. - - It provides methods to convert IR node object to Doc, operate on Frame - objects and change dispatch tokens. - """ - - ir_prefix: Mapping[str, str] - vars: VarTable - frames: Sequence[Frame] - dispatch_tokens: Sequence[str] - - def __init__(self, ir_prefix: Dict[str, str]): - """ - Create a new IRDocsifier. - - Parameters - ---------- - ir_prefix : Dict[str, str] - The ir prefix to use. Key is the IR dispatch token and - value is the name of identifier for this IR's namespace in TVMScript. - """ - self.__init_handle_by_constructor__(_ffi_api.IRDocsifier, ir_prefix) # type: ignore # pylint: disable=no-member - - _TObject = TypeVar("_TObject", bound=Object) - - @classmethod - def set_dispatch( - cls, - node_type: Type[_TObject], - dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc], - dispatch_token: str = "", - ) -> None: - """ - Set the dispatch function to transform a particular IR node type to Doc - - Parameters - ---------- - node_type : Type[_TObject] - The type of object to dispatch on. - dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc] - The dispatch function. It's called to transform IR node object to Doc. - dispatch_token : str - Function will only be called when this dispatch_token is the same as the one - on the top of IRDocsifier's dispatch_tokens stack. An empty dispatch token - means registering as default dispatch function, which will be called when - there is no dispatch function registered with the current dispatch token. - """ - type_index = get_object_type_index(node_type) - if type_index is None: - raise TypeError(f"{type(node_type)} is not a registered TVM object type.") - - _ensure_cleanup_function_registered() - _ffi_api.IRDocsifierSetDispatch( # type: ignore # pylint: disable=no-member - dispatch_token, type_index, dispatch_function - ) - _REGISTERED_TYPES.add((dispatch_token, type_index)) - - @classmethod - def set_root_dispatch( - cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc] - ) -> None: - """ - Set the root dispatch function for an IR. - - The root dispatch function will be called with the root node of an IR graph - that's being transformed to Doc. This provides an injection point for - each IR's printer implemention to add specialized logic, for example, - pushing a special Frame to the IRDocsifier before doing actual IR->Doc - transformation. - - The simplest root dispatch function is - ``` - def f(obj, ir_docsifier) - return ir_docsifier.as_doc(obj, ObjectPath.root()) - ``` - - Parameters - ---------- - root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc] - The root dispatch function. It's called with the root node to be printed. - dispatch_token : str - The dispatch token of the IR that root_dispatch_funnction applies to. - """ - - def dispatch_function(obj: RootNodeContainer, _, ir_docsifier): - return root_dispatch_function(obj.root_node, ir_docsifier) - - cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token) - - def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc: - """ - Transform the input object into Doc. - - Parameters - ---------- - obj : Object - The IR node object. - object_path : ObjectPath - The object path of this object. It's used for locating diagnostic message. - - Returns - ------- - doc : Doc - The doc for this object. - """ - return _ffi_api.IRDocsifierAsDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member - - def get_frame(self, frame_type: Type[Frame]) -> Optional[Frame]: - """ - Get the top frame with type `frame_type`. - - Parameters - ---------- - frame_type : Type[Frame] - The target frame type. - - Returns - ------- - frame : Optional[Frame] - The frame if found, otherwise None. - """ - for i in range(len(self.frames) - 1, -1, -1): - if isinstance(self.frames[i], frame_type): - return self.frames[i] - return None - - @contextmanager - def dispatch_token(self, token: str): - """ - Push a new dispatch token to the stack. - - Parameters - ---------- - token : str - The token to push. - - Returns - ------- - A context manager that pops this dispatch token when exits. - """ - with ExitStack() as stack: - _ffi_api.IRDocsifierPushDispatchToken(self, token) # type: ignore # pylint: disable=no-member - stack.callback(_ffi_api.IRDocsifierPopDispatchToken, self) # type: ignore # pylint: disable=no-member - yield - - _TFrame = TypeVar("_TFrame", bound=Frame) - - @contextmanager - def frame(self, frame: _TFrame) -> Generator[_TFrame, None, None]: - """ - Push a new frame to the stack. - - Parameters - ---------- - frame : Frame - The frame to push. - - Returns - ------- - A context manager that pops this frame when exits. - """ - with ExitStack() as stack: - stack.enter_context(frame) - _ffi_api.IRDocsifierPushFrame(self, frame) # type: ignore # pylint: disable=no-member - stack.callback(_ffi_api.IRDocsifierPopFrame, self) # type: ignore # pylint: disable=no-member - yield frame diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py new file mode 100644 index 000000000000..120ef03f57d7 --- /dev/null +++ b/python/tvm/script/printer/printer.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The printer interface""" + +from typing import Mapping, Optional + +from tvm.runtime.object_path import ObjectPath + +from . import _ffi_api + + +def script( + obj, + ir_prefix: Optional[Mapping[str, str]] = None, + indent_space: int = 4, + print_line_number: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, +): + """Print a TVM IR as a TVMScript text format. + + Parameters + ---------- + obj : object + An TVM object representing TVM IR + ir_prefix : Optional[Mapping[str, str]] + A mapping from IR type to the prefix of the script. + Default to {"ir": "I", "tir": T} + indent_space : int = 4 + The number of spaces to indent + print_line_number : bool = False + Whether to print line number + num_context_lines : int = -1 + The number of context lines to print. -1 means all lines. + path_to_underline : Optional[ObjectPath] + The path to underline in the script. + + Returns + ------- + script : str + The TVMScript text format + """ + if ir_prefix is None: + ir_prefix = { + "ir": "I", + "tir": "T", + } + return _ffi_api.Script( # type: ignore # pylint: disable=no-member + obj, ir_prefix, indent_space, print_line_number, num_context_lines, path_to_underline + ) diff --git a/python/tvm/script/printer/var_table.py b/python/tvm/script/printer/var_table.py deleted file mode 100644 index ea1fa41b3210..000000000000 --- a/python/tvm/script/printer/var_table.py +++ /dev/null @@ -1,118 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Functions to print doc into text format""" - -from typing import Callable, Optional - -from tvm._ffi import register_object -from tvm.runtime import Object, ObjectPath - -from . import _ffi_api -from .doc import ExprDoc, IdDoc -from .frame import Frame - - -@register_object("script.printer.VarTable") -class VarTable(Object): - """ - Variable Table manages mapping from variable object to ExprDoc during - the process of printing TVMScript. - """ - - def __init__(self): - """ - Create an empty VarTable. - """ - self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member - - def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc: - """ - Define a variable by name. - - Parameters - ---------- - obj : Object - The variable object. - name_hint : str - The hint for variable name. - object_path : ObjectPath - The object path to be associated with the returned ExprDoc. - frame : Frame - Then frame that this variable is defined in. - - Returns - ------- - doc : IdDoc - The doc for this variable. - """ - return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member - - def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None: - """ - Define a variable by ExprDoc. - - Parameters - ---------- - obj : Object - The variable object. - doc_factory : Callable[[], ExprDoc] - The hint for variable name. - frame : Frame - Then frame that this variable is defined in. - - Returns - ------- - None - """ - _ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member - - def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]: - """ - Get the doc for a variable. - - Parameters - ---------- - obj : Object - The variable object. - object_path : ObjectPath - The object path to be associated with the returned ExprDoc. - - Returns - ------- - doc : ExprDoc - The doc for this variable. - """ - return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member - - def is_var_defined(self, obj: Object) -> bool: - """ - Check whether a variable is defined. - - Parameters - ---------- - obj : Object - The variable object. - - Returns - ------- - is_defined : bool - Whether the variable is defined. - """ - return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member - - def __contains__(self, obj: Object) -> bool: - return self.is_var_defined(obj) diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 1ca7ced8e8a7..f41b40c92cc9 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -27,18 +27,12 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } -ExprDoc ExprDocNode::Attr(TracedObject attr) const { - auto doc = AttrAccessDoc(GetRef(this), attr.Get()); - doc->source_paths.push_back(attr.GetPath()); - return std::move(doc); -} - ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } ExprDoc ExprDocNode::Call(Array args) const { - return CallDoc(GetRef(this), args, {}, {}); + return CallDoc(GetRef(this), args, Array(), Array()); } ExprDoc ExprDocNode::Call(Array args, Array kwargs_keys, @@ -258,7 +252,7 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") - .set_body_typed(LiteralDoc::Int); + .set_body_typed(LiteralDoc::Int); TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") .set_body_typed(LiteralDoc::Boolean); TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc similarity index 100% rename from src/script/printer/base_doc_printer.cc rename to src/script/printer/doc_printer/base_doc_printer.cc diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h similarity index 97% rename from src/script/printer/base_doc_printer.h rename to src/script/printer/doc_printer/base_doc_printer.h index f3fb24d946e1..db1d733d96ad 100644 --- a/src/script/printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -16,11 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ +#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_ #include -#include #include #include @@ -287,4 +286,4 @@ class DocPrinter { } // namespace script } // namespace tvm -#endif // TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ +#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_ diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc similarity index 98% rename from src/script/printer/python_doc_printer.cc rename to src/script/printer/doc_printer/python_doc_printer.cc index 753f907c423c..6851baf63866 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -21,10 +21,11 @@ #include #include +#include #include -#include "../../support/str_escape.h" -#include "../../support/utils.h" +#include "../../../support/str_escape.h" +#include "../../../support/utils.h" #include "./base_doc_printer.h" namespace tvm { @@ -294,7 +295,11 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } else if (const auto* float_imm = value.as()) { // TODO(yelite): Make float number printing roundtrippable output_.precision(17); - output_ << float_imm->value; + if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { + output_ << '"' << float_imm->value << '"'; + } else { + output_ << float_imm->value; + } } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc deleted file mode 100644 index b342c7c886c7..000000000000 --- a/src/script/printer/frame.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -MetadataFrame::MetadataFrame() : MetadataFrame(make_object()) {} - -VarDefFrame::VarDefFrame() : VarDefFrame(make_object()) {} - -TVM_REGISTER_NODE_TYPE(FrameNode); -TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback") - .set_body_typed([](Frame frame, runtime::TypedPackedFunc callback) { - frame->AddExitCallback(callback); - }); -TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope") - .set_body_method(&FrameNode::EnterWithScope); -TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope") - .set_body_method(&FrameNode::ExitWithScope); - -TVM_REGISTER_NODE_TYPE(MetadataFrameNode); -TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() { - return MetadataFrame(); -}); - -TVM_REGISTER_NODE_TYPE(VarDefFrameNode); -TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); }); - -} // namespace printer -} // namespace script -} // namespace tvm diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc new file mode 100644 index 000000000000..c4ecf92e9116 --- /dev/null +++ b/src/script/printer/ir/ir.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(IRFrameNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { + std::vector> functions{mod->functions.begin(), + mod->functions.end()}; + // print "main" first + std::sort(functions.begin(), functions.end(), [](const auto& lhs, const auto& rhs) { + String lhs_name = lhs.first->name_hint; + String rhs_name = rhs.first->name_hint; + if (lhs_name == "main") { + lhs_name = ""; + } + if (rhs_name == "main") { + rhs_name = ""; + } + return lhs_name < rhs_name; + }); + ICHECK(!d->mod.defined()); + d->mod = mod; + { + With f(d); + (*f)->AddDispatchToken(d, "ir"); + for (const auto& kv : functions) { + GlobalVar gv = kv.first; + BaseFunc func = kv.second; + (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); + } + return ClassDoc(IdDoc("Module"), {IR(d)}, (*f)->stmts); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc { + return d->AsDoc(attrs->dict, p->Attr("dict")); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { + return IdDoc("Op")->Call({LiteralDoc::Str(op->name)}); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc new file mode 100644 index 000000000000..bd2792167194 --- /dev/null +++ b/src/script/printer/ir/misc.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](String s, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Str(s); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch>( // + "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { + int n = array.size(); + Array results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(d->AsDoc(array[i], p->ArrayIndex(i))); + } + return ListDoc(results); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch>( // + "", [](Map dict, ObjectPath p, IRDocsifier d) -> Doc { + using POO = std::pair; + std::vector items{dict.begin(), dict.end()}; + bool is_str_map = true; + for (const auto& kv : items) { + if (!kv.first.as()) { + is_str_map = false; + break; + } + } + if (is_str_map) { + std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { + return Downcast(lhs.first) < Downcast(rhs.first); + }); + } else { + std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { + return lhs.first.get() < rhs.first.get(); + }); + } + int n = dict.size(); + Array ks; + Array vs; + ks.reserve(n); + vs.reserve(n); + for (int i = 0; i < n; ++i) { + ks.push_back(d->AsDoc(items[i].first, p->MissingMapEntry())); + vs.push_back(d->AsDoc(items[i].second, p->MapValue(items[i].first))); + } + return DictDoc(ks, vs); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h new file mode 100644 index 000000000000..4065b895c1bb --- /dev/null +++ b/src/script/printer/ir/utils.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_IR_UTILS_H_ +#define TVM_SCRIPT_PRINTER_IR_UTILS_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +inline ExprDoc IR(const IRDocsifier& d) { return IdDoc("tvm")->Attr("script"); } + +class IRFrameNode : public FrameNode { + public: + void VisitAttrs(AttrVisitor* v) { FrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.printer.IRFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRFrameNode, FrameNode); +}; + +class IRFrame : public Frame { + public: + explicit IRFrame(const IRDocsifier& d) { + ObjectPtr n = make_object(); + n->stmts.clear(); + n->d = d.get(); + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_IR_UTILS_H_ diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 7f032ec50269..8584f360312f 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -20,21 +20,136 @@ #include #include #include -#include -#include namespace tvm { namespace script { namespace printer { -Doc IRDocsifierNode::AsDocImpl(const TracedObject& obj) const { - return IRDocsifier::vtable()(dispatch_tokens.back(), obj, GetRef(this)); +String GenerateUniqueName(std::string name_hint, std::unordered_set* defined_names) { + for (char& c : name_hint) { + if (c != 'c' && !std::isalnum(c)) { + c = '_'; + } + } + std::string name = name_hint; + for (int i = 1; !defined_names->insert(name).second; ++i) { + name = name_hint + "_" + std::to_string(i); + } + return name; +} + +IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { + String name = GenerateUniqueName(name_hint, &this->defined_names); + DocCreator doc_factory = [name]() { return IdDoc(name); }; + auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); + ICHECK(result.second) << "Duplicated object: " << obj; + IdDoc def_doc(name); + frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); + return def_doc; +} + +void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) { + ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; + ICHECK(!doc_factory()->IsInstance()) + << "IRDocsifierNode::Define cannot be used for variable that's mapped to IdDoc."; + obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); + frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); +} + +Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { + auto it = obj2info.find(obj); + if (it == obj2info.end()) { + return NullOpt; + } + return it->second.creator(); +} + +bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } + +void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { + auto it = obj2info.find(obj); + ICHECK(it != obj2info.end()) << "No such object: " << obj; + if (it->second.name.defined()) { + defined_names.erase(it->second.name.value()); + } + obj2info.erase(it); +} + +void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, + runtime::TypedPackedFunc is_var) { + class Visitor : public AttrVisitor { + public: + inline void operator()(ObjectRef obj) { Visit("", &obj); } + + private: + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, ObjectRef* value) final { + const Object* obj = value->get(); + if (obj == nullptr) { + return; + } + stack_.push_back(obj); + if (obj->IsInstance()) { + const ArrayNode* array = static_cast(obj); + for (ObjectRef element : *array) { + this->Visit("", &element); + } + } else if (obj->IsInstance()) { + const MapNode* map = static_cast(obj); + for (std::pair kv : *map) { + this->Visit("", &kv.first); + this->Visit("", &kv.second); + } + } else { + vtable_->VisitAttrs(const_cast(obj), this); + } + if (is_var(GetRef(obj))) { + HandleVar(obj); + } + stack_.pop_back(); + } + + void HandleVar(const Object* var) { + if (common_prefix.count(var) == 0) { + common_prefix[var] = stack_; + return; + } + std::vector& a = common_prefix[var]; + std::vector& b = stack_; + int n = std::min(a.size(), b.size()); + for (int i = 0; i < n; ++i) { + if (a[i] != b[i]) { + a.resize(i); + break; + } + } + } + + ReflectionVTable* vtable_ = ReflectionVTable::Global(); + std::vector stack_; + + public: + runtime::TypedPackedFunc is_var; + std::unordered_map> common_prefix; + }; + Visitor visitor; + visitor.is_var = is_var; + visitor(root); + this->common_prefix = std::move(visitor.common_prefix); } IRDocsifier::IRDocsifier(Map ir_prefix) { auto n = make_object(); n->ir_prefix = std::move(ir_prefix); - n->dispatch_tokens.push_back(kDefaultDispatchToken); + n->dispatch_tokens.push_back(""); data_ = std::move(n); } @@ -43,65 +158,8 @@ IRDocsifier::FType& IRDocsifier::vtable() { return inst; } -RootNodeContainer::RootNodeContainer(ObjectRef root_node) { - auto n = make_object(); - n->root_node = std::move(root_node); - data_ = std::move(n); -} - -// Add a default dispatch for the RootNodeContainer to throw error. -// To add implementation for a new IR, RootNodeContainer needs to be -// registered under the dispatch token of that IR, like: -// \code -// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) -// .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { -// const ObjectRef& root_node = obj.Get()->root_node; -// \\ More specialized logic for your IR. -// return p->AsDoc(MakeTraced(root_node)); -// }); -// \endcode -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { - String top_dispatch_token = p->dispatch_tokens.back(); - ICHECK_NE(top_dispatch_token, ""); - ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; - throw; - }); - +TVM_REGISTER_NODE_TYPE(FrameNode); TVM_REGISTER_NODE_TYPE(IRDocsifierNode); -TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map ir_prefix) { - return IRDocsifier(ir_prefix); -}); -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierAsDoc") - .set_body_typed([](IRDocsifier p, ObjectRef obj, ObjectPath obj_path) { - return p->AsDoc(MakeTraced(obj, obj_path)); - }); - -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushDispatchToken") - .set_body_typed([](IRDocsifier p, String token) { p->dispatch_tokens.push_back(token); }); -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopDispatchToken").set_body_typed([](IRDocsifier p) { - p->dispatch_tokens.pop_back(); -}); - -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushFrame") - .set_body_typed([](IRDocsifier p, Frame frame) { p->frames.push_back(frame); }); -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopFrame").set_body_typed([](IRDocsifier p) { - p->frames.pop_back(); -}); - -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierSetDispatch") - .set_body_typed([](String token, uint64_t type_index, runtime::PackedFunc f) { - IRDocsifier::vtable().set_dispatch(token, type_index, std::move(f)); - }); -TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch") - .set_body_typed([](String token, uint64_t type_index) { - IRDocsifier::vtable().remove_dispatch(token, type_index); - }); - -TVM_REGISTER_NODE_TYPE(RootNodeContainerNode); -TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) { - return RootNodeContainer(root_node); -}); } // namespace printer } // namespace script diff --git a/src/script/printer.cc b/src/script/printer/printer.cc similarity index 57% rename from src/script/printer.cc rename to src/script/printer/printer.cc index 051b774ba6ac..47fd0b89b09e 100644 --- a/src/script/printer.cc +++ b/src/script/printer/printer.cc @@ -16,38 +16,28 @@ * specific language governing permissions and limitations * under the License. */ - #include -#include -#include -#include -#include -#include +#include namespace tvm { namespace script { namespace printer { -String Script( // - const ObjectRef& root_node, // - String ir_name, // - Map ir_prefix, // - int indent_spaces, // - bool print_line_numbers, // - int num_context_lines, // - Optional path_to_underline // -) { - IRDocsifier ir_docsifier(ir_prefix); - - auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name); - - Doc doc = ir_docsifier->AsDoc(MakeTraced(RootNodeContainer(root_node))); - +String Script(ObjectRef obj, Map ir_prefix, int indent_spaces, + bool print_line_numbers, int num_context_lines, + Optional path_to_underline) { + IRDocsifier d(ir_prefix); + Doc doc = d->AsDoc(obj, ObjectPath::Root()); return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines, path_to_underline); } -TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script); +Default* Default::Instance() { + static Default inst; + return &inst; +} + +TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc new file mode 100644 index 000000000000..f6dbf616a5a3 --- /dev/null +++ b/src/script/printer/tir/block.cc @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // + Optional opt_realize, Optional opt_realize_p) { + With frame(d, block); + ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); + const tir::BlockRealizeNode* realize = opt_realize.value().get(); + const ObjectPathNode* realize_p = opt_realize_p.get(); + // Step 1. Handle block var and block bindings + int n_vars = block->iter_vars.size(); + for (int i = 0; i < n_vars; ++i) { + tir::IterVar iter_var = block->iter_vars[i]; + ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); + ExprDoc rhs = TIR(d)->Attr("axis"); + if (iter_var->iter_type == tir::IterVarType::kDataPar) { + rhs = rhs->Attr("spatial"); + } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) { + rhs = rhs->Attr("reduce"); + } else if (iter_var->iter_type == tir::IterVarType::kOrdered) { + rhs = rhs->Attr("scan"); + } else if (iter_var->iter_type == tir::IterVarType::kOpaque) { + rhs = rhs->Attr("opaque"); + } else { + LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " + << tir::IterVarType2String(iter_var->iter_type); + } + ExprDoc dom{nullptr}; + if (tir::is_zero(iter_var->dom->min)) { + ExprDoc extent = d->AsDoc(iter_var->dom->extent, // + iter_var_p->Attr("dom")->Attr("extent")); + dom = extent; + } else { + ExprDoc min = d->AsDoc(iter_var->dom->min, iter_var_p->Attr("dom")->Attr("min")); + ExprDoc max = d->AsDoc(iter_var->dom->min + iter_var->dom->extent, + iter_var_p->Attr("dom")->Attr("extent")); + dom = TupleDoc({min, max}); + } + if (realize) { + ExprDoc binding = d->AsDoc(realize->iter_values[i], // + realize_p->Attr("iter_values")->ArrayIndex(i)); + rhs = rhs->Call({dom, binding}); + } else { + rhs = rhs->Call({dom}); + } + (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt)); + } + // Step 2. Handle block predicate + if (realize) { + ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); + if (!tir::is_one(realize->predicate)) { + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("where")->Call( + {d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); + } + } + // Step 3. Handle block read/write regions + { + Array reads; + for (int i = 0, n = block->reads.size(); i < n; ++i) { + reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayIndex(i))); + } + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("reads")->Call(reads))); + Array writes; + for (int i = 0, n = block->writes.size(); i < n; ++i) { + writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayIndex(i))); + } + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("writes")->Call(writes))); + } + // Step 4. Handle block attributes + if (!block->annotations.empty()) { + (*frame)->stmts.push_back(ExprStmtDoc( + TIR(d) + ->Attr("block_attr") + ->Call({d->AsDoc(block->annotations, block_p->Attr("annotations"))}))); + } + // Step 5. Handle `alloc_buffer` + for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { + tir::Buffer buffer = block->alloc_buffers[i]; + ObjectPath buffer_p = block_p->Attr("alloc_buffers")->ArrayIndex(i); + IdDoc lhs = DefineBuffer(buffer, *frame, d); + ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d); + (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } + // Step 6. Handle `match_buffer` + for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { + tir::MatchBufferRegion buffer_region = block->match_buffers[i]; + ObjectPath buffer_region_p = block_p->Attr("match_buffers")->ArrayIndex(i); + StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); + (*frame)->stmts.push_back(doc); + } + // Step 7. Handle init block + if (block->init.defined()) { + tir::Stmt init = block->init.value(); + With init_frame(d, init); + AsDocBody(init, block_p->Attr("init"), init_frame->get(), d); + (*frame)->stmts.push_back( + ScopeDoc(NullOpt, TIR(d)->Attr("init")->Call({}), (*init_frame)->stmts)); + } + // Step 8. Handle block body + AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); + return ScopeDoc(NullOpt, TIR(d)->Attr("block")->Call({LiteralDoc::Str(block->name_hint)}), + (*frame)->stmts); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::BlockRealize realize, ObjectPath p, IRDocsifier d) -> Doc { + return PrintBlock(d, realize->block, p->Attr("block"), realize, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Block block, ObjectPath p, IRDocsifier d) -> Doc { + return PrintBlock(d, block, p, NullOpt, NullOpt); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc { + Frame frame = d->frames.back(); + ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); + ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); + ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer", {src_buffer}, p->Attr("buffer"), + d->frames.back(), d); + return AssignDoc(lhs, rhs, NullOpt); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc new file mode 100644 index 000000000000..3e1d71af4acd --- /dev/null +++ b/src/script/printer/tir/buffer.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d) { + Map kwargs; + auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const String& key) { + if (Optional doc = d->GetVarDoc(e)) { + kwargs.Set(key, doc.value()); + return false; + } + if (e->IsInstance()) { + d->Define(e, frame, [=]() { return d->AsDoc(buffer, p)->Attr(key); }); + return true; + } + kwargs.Set(key, d->AsDoc(e, p)); + return false; + }; + auto array_out_line_var_def = [&](const Array& array, const ObjectPath& p, + const String& key) { + int n = array.size(); + Array results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + PrimExpr s = array[i]; + ObjectPath s_path = p->ArrayIndex(i); + // Add out-of-line definition for a new Var in shape + results.push_back(d->AsDoc(s, s_path)); + } + kwargs.Set(key, TupleDoc(results)); + }; + // Step 1. Handle `buffer.shape` + array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape"); + // Step 2. Handle `buffer.dtype` + if (buffer->dtype != Default::BufferDType()) { + kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype)); + } + // Step 3. Handle `buffer.data` + implicit_var_def(buffer->data, p->Attr("data"), "data"); + // Step 4. Handle `buffer.strides` + if (!buffer->strides.empty()) { + array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides"); + } + // Step 5. Handle `buffer.elem_offset` + bool needs_print_factor = false; + if (const auto* int_imm = buffer->elem_offset.as()) { + if (int_imm->value != 0) { + kwargs.Set("elem_offset", d->AsDoc(buffer->elem_offset, p->Attr("elem_offset"))); + } + } else { + needs_print_factor = + implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"), "elem_offset"); + } + // Step 6. Handle `buffer.scope` + { + String scope = buffer.scope(); + if (scope != "global") { + kwargs.Set("scope", LiteralDoc::Str(scope)); + } + } + // Step 7. Handle `buffer.data_alignment` + if (buffer->data_alignment != runtime::kAllocAlignment) { + kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment)); + } + // Step 8. Handle `buffer.offset_factor` + if (needs_print_factor || buffer->offset_factor != 1) { + kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor)); + } + // Step 9. Handle `buffer.buffer_type` + if (buffer->buffer_type != tir::BufferType::kDefault) { + kwargs.Set("type", LiteralDoc::Str("auto")); + } + // Step 10. Handle `buffer.axis_separator` + if (!buffer->axis_separators.empty()) { + kwargs.Set("axis_separators", + d->AsDoc(buffer->axis_separators, p->Attr("axis_separators"))); + } + return kwargs; +} + +ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Array args) { + Array kwargs_keys; + Array kwargs_values; + for (String s : {"shape", "dtype"}) { + if (Optional doc = attrs.Get(s)) { + args.push_back(doc.value()); + } + } + for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", "type", + "axis_separators"}) { + if (Optional doc = attrs.Get(s)) { + kwargs_keys.push_back(s); + kwargs_values.push_back(doc.value()); + } + } + return prefix->Call(args, kwargs_keys, kwargs_values); +} + +ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, + const ObjectPath& p, const Frame& frame, const IRDocsifier& d) { + return BufferCall(/*prefix=*/TIR(d)->Attr(method), + /*attrs=*/BufferAttrs(buffer, p, frame, d), + /*args=*/args); +} + +Doc BufferIndex(const PrimExpr& index, const ObjectPath& p, const IRDocsifier& d) { + if (const auto* ramp = index.as()) { + if (const auto* stride = ramp->stride.as()) { + ExprDoc start = d->AsDoc(ramp->base, p->Attr("base")); + ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, p->Attr("lanes")); + Optional step = NullOpt; + if (stride->value != 1) { + step = d->AsDoc(ramp->stride, p->Attr("stride")); + } + return SliceDoc(start, stop, step); + } + } + return d->AsDoc(index, p); +} + +ExprDoc BufferIndices(const tir::Buffer& buffer, const Array& indices, + const ObjectPath& p, const IRDocsifier& d) { + int n = indices.size(); + Array indices_doc; + indices_doc.reserve(n); + for (int i = 0; i < n; ++i) { + indices_doc.push_back(BufferIndex(indices[i], p->Attr("indices")->ArrayIndex(i), d)); + } + return d->AsDoc(buffer, p->Attr("buffer"))[indices_doc]; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc prefix = d->AsDoc(buffer_region->buffer, p->Attr("buffer")); + p = p->Attr("region"); + Array region = buffer_region->region; + int n = region.size(); + Array indices; + indices.reserve(n); + for (int i = 0; i < n; ++i) { + Range range = region[i]; + ExprDoc min = d->AsDoc(range->min, p->ArrayIndex(i)->Attr("min")); + if (tir::is_one(range->extent)) { + indices.push_back(min); + } else { + ExprDoc max = + d->AsDoc(range->min + range->extent, p->ArrayIndex(i)->Attr("extent")); + indices.push_back(SliceDoc(min, max, NullOpt)); + } + } + return prefix[indices]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { + return AssignDoc(/*lhs=*/BufferIndices(store->buffer, store->indices, p, d), + /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { + return BufferIndices(load->buffer, load->indices, p, d); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc new file mode 100644 index 000000000000..f9b4eb621447 --- /dev/null +++ b/src/script/printer/tir/expr.cc @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Doc PrintVar(const tir::Var& var, const ObjectPath& p, const IRDocsifier& d) { + if (!d->IsVarDefined(var)) { + if (Optional opt_f = FindLowestVarDef(var, d)) { + ExprDoc lhs = DefineVar(var, opt_f.value(), d); + Type type = var->type_annotation; + if (const auto* ptr_type = type.as()) { + ICHECK(ptr_type->element_type->IsInstance()); + ExprDoc rhs = d->AsDoc(type, p->Attr("type_annotation")); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } else { + ExprDoc rhs = TIR(d)->Attr("var")->Call({LiteralDoc::DataType(var->dtype)}); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } + } + } + if (Optional doc = d->GetVarDoc(var)) { + return doc.value(); + } + LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](tir::Var var, ObjectPath p, IRDocsifier d) -> Doc { + return PrintVar(var, p, d); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](tir::SizeVar var, ObjectPath p, IRDocsifier d) -> Doc { + return PrintVar(var, p, d); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::IterVar var, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d) + ->Attr("iter_var") + ->Call({ + d->AsDoc(var->var, p->Attr("var")), + d->AsDoc(var->dom, p->Attr("dom")), + LiteralDoc::Str(IterVarType2String(var->iter_type)), + LiteralDoc::Str(var->thread_tag), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // + .set_dispatch("", [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc { + if (!d->IsVarDefined(buffer)) { + if (Optional opt_f = FindLowestVarDef(buffer, d)) { + ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); + ExprDoc rhs = BufferDecl(buffer, "buffer_decl", // TODO(@junrushao): name confusing + {}, p, opt_f.value(), d); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } + } + if (Optional doc = d->GetVarDoc(buffer)) { + return doc.value(); + } + LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + if (a->IsInstance()) { + return TIR(d)->Attr("Not")->Call({a}); + } + return OperationDoc(OperationDocNode::Kind::kNot, {a}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc { + return d->AsDoc(s->value, p->Attr("value")); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc dtype = LiteralDoc::DataType(cast->dtype); + ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); + return TIR(d)->Attr("Cast")->Call({dtype, value}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("Select")->Call({ + d->AsDoc(select->condition, p->Attr("condition")), + d->AsDoc(select->true_value, p->Attr("true_value")), + d->AsDoc(select->false_value, p->Attr("false_value")), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Ramp ramp, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("Ramp")->Call({ + d->AsDoc(ramp->base, p->Attr("base")), + d->AsDoc(ramp->stride, p->Attr("stride")), + LiteralDoc::Int(ramp->lanes), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Broadcast bc, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d) + ->Attr("Broadcast") + ->Call({ + d->AsDoc(bc->value, p->Attr("value")), + LiteralDoc::Int(bc->lanes), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("Shuffle")->Call({ + d->AsDoc(shuffle->vectors, p->Attr("vectors")), + d->AsDoc(shuffle->indices, p->Attr("indices")), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::CommReducer r, ObjectPath p, IRDocsifier d) -> Doc { + ICHECK_EQ(r->lhs.size(), r->rhs.size()); + LambdaDoc lambda{nullptr}; + { + With f(d, r); + int n_vars = r->lhs.size(); + Array vars; + vars.reserve(n_vars + n_vars); + for (int i = 0; i < n_vars; ++i) { + vars.push_back(DefineVar(r->lhs[i], *f, d)); + } + for (int i = 0; i < n_vars; ++i) { + vars.push_back(DefineVar(r->rhs[i], *f, d)); + } + int n_results = r->result.size(); + Array results; + results.reserve(n_results); + for (int i = 0; i < n_results; ++i) { + results.push_back(d->AsDoc(r->result[i], p->Attr("result")->ArrayIndex(i))); + } + if (results.size() == 1) { + lambda = LambdaDoc(vars, results[0]); + } else { + lambda = LambdaDoc(vars, TupleDoc(results)); + } + } + ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); + return TIR(d)->Attr("comm_reducer")->Call({lambda, id}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("let")->Call({ + d->AsDoc(let->var, p->Attr("var")), + d->AsDoc(let->value, p->Attr("value")), + d->AsDoc(let->body, p->Attr("body")), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Call call, ObjectPath p, IRDocsifier d) -> Doc { + static const OpAttrMap& op_names = + Op::GetAttrMap("TScriptPrinterName"); + static const std::unordered_set dtype_first_arg = { + tir::builtin::reinterpret().get(), + tir::builtin::call_extern().get(), + tir::builtin::call_llvm_intrin().get(), // + tir::builtin::call_llvm_pure_intrin().get(), // + tir::builtin::call_pure_extern().get(), // + tir::builtin::ptx_mma().get(), + tir::builtin::ptx_mma_sp().get(), + tir::builtin::ptx_ldmatrix().get(), + tir::builtin::ptx_cp_async().get(), + tir::builtin::mma_store().get(), + tir::builtin::mma_fill().get(), + tir::builtin::vectorlow().get(), + tir::builtin::vectorhigh().get(), + tir::builtin::vectorcombine().get(), + Op::Get("tir.type_annotation").get(), + }; + static const std::unordered_set dtype_last_arg = { + tir::builtin::tvm_struct_get().get(), + }; + ExprDoc prefix{nullptr}; + if (const auto* op = call->op.as()) { + String name = op_names[GetRef(op)]; + prefix = TIR(d)->Attr(name); + } else if (const auto* gv = call->op.as()) { + prefix = LiteralDoc::Str(gv->name_hint); + } else { + LOG(FATAL) << "call: " << call; + } + Array args; + int n_args = call->args.size(); + args.reserve(n_args + 1); + if (dtype_first_arg.count(call->op.get())) { + args.push_back(LiteralDoc::DataType(call->dtype)); + } + for (int i = 0; i < n_args; ++i) { + args.push_back(d->AsDoc(call->args[i], p->Attr("args")->ArrayIndex(i))); + } + if (dtype_last_arg.count(call->op.get())) { + args.push_back(LiteralDoc::DataType(call->dtype)); + } + return prefix->Call(args); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("Any")->Call({}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: ProducerLoad should never exist in TIR: " << load; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Load load, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: Load has been deprecated for BufferLoad: " << load; + }); + +#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", \ + [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + return TIR(d)->Attr(OpString)->Call({a, b}); \ + }); + +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch( \ + "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + if (a->IsInstance() && b->IsInstance()) { \ + return TIR(d)->Attr(OpString)->Call({a, b}); \ + } \ + return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ + }); + +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd); +TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr); + +TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod"); +TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min"); +TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max"); + +#undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR +#undef TVM_SCRIPT_PRINTER_DEF_BINARY + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc new file mode 100644 index 000000000000..6a375935bd79 --- /dev/null +++ b/src/script/printer/tir/for_loop.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::For loop, ObjectPath p, IRDocsifier d) -> Doc { + // Step 1. Check syntactic sugar: `T.grid` + std::vector grid; + std::unordered_set grid_loop_vars; + auto f_var_dep = [&grid_loop_vars](const PrimExpr& e) -> bool { + return tir::UsesVar(e, [&grid_loop_vars](const tir::VarNode* v) -> bool { // + return grid_loop_vars.count(v); + }); + }; + for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { + ICHECK(l->loop_var->dtype == l->min->dtype); + ICHECK(l->loop_var->dtype == l->extent->dtype); + if (l->kind != tir::ForKind::kSerial || // + !tir::is_zero(l->min) || // + !l->annotations.empty() || // + f_var_dep(l->extent)) { + break; + } + grid.push_back(l); + grid_loop_vars.insert(l->loop_var.get()); + } + With f(d, loop); + // Step 2. Construct `T.grid` + if (grid.size() > 1) { + int n = grid.size(); + Array lhs; + Array rhs; + lhs.reserve(n); + rhs.reserve(n); + for (int i = 0; i < n; ++i) { + const tir::ForNode* loop = grid[i]; + lhs.push_back(DefineVar(loop->loop_var, *f, d)); + rhs.push_back(d->AsDoc(loop->extent, p->Attr("extent"))); + p = p->Attr("body"); + } + AsDocBody(grid.back()->body, p, (*f).get(), d); + return ForDoc(TupleDoc(lhs), TIR(d)->Attr("grid")->Call(rhs), (*f)->stmts); + } + // Step 3. If not `T.grid`, print loop kind accordingly + IdDoc lhs = DefineVar(loop->loop_var, *f, d); + Optional min = NullOpt; + Optional max = NullOpt; + Optional annotations = NullOpt; + Optional thread = NullOpt; + if (tir::is_zero(loop->min)) { + max = d->AsDoc(loop->extent, p->Attr("extent")); + } else { + min = d->AsDoc(loop->min, p->Attr("min")); + max = d->AsDoc(loop->min + loop->extent, p->Attr("extent")); + } + if (!loop->annotations.empty()) { + annotations = d->AsDoc(loop->annotations, p->Attr("annotations")); + } + ExprDoc prefix = TIR(d); + if (loop->kind == tir::ForKind::kSerial) { + if (loop->annotations.empty()) { + prefix = IdDoc("range"); + } else { + prefix = prefix->Attr("serial"); + } + } else if (loop->kind == tir::ForKind::kParallel) { + prefix = prefix->Attr("parallel"); + } else if (loop->kind == tir::ForKind::kUnrolled) { + prefix = prefix->Attr("unroll"); + } else if (loop->kind == tir::ForKind::kVectorized) { + prefix = prefix->Attr("vectorized"); + } else if (loop->kind == tir::ForKind::kThreadBinding) { + prefix = prefix->Attr("thread_binding"); + thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag); + } else { + LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); + } + Array args; + Array kwargs_keys; + Array kwargs_values; + if (min.defined()) { + args.push_back(min.value()); + } + if (max.defined()) { + args.push_back(max.value()); + } + if (thread.defined()) { + kwargs_keys.push_back("thread"); + kwargs_values.push_back(thread.value()); + } + if (annotations.defined()) { + kwargs_keys.push_back("annotations"); + kwargs_values.push_back(annotations.value()); + } + ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values); + AsDocBody(loop->body, p, (*f).get(), d); + return ForDoc(lhs, rhs, (*f)->stmts); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc new file mode 100644 index 000000000000..d47a60209e43 --- /dev/null +++ b/src/script/printer/tir/function.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) { + if (!d->mod.defined()) { + return "main"; + } + for (const auto& kv : d->mod.value()->functions) { + if (kv.second.same_as(f)) { + return kv.first->name_hint; + } + } + return "main"; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { + d->SetCommonPrefix(func, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance(); + }); + With frame(d, func); + (*frame)->AddDispatchToken(d, "tir"); + int n_args = func->params.size(); + // Step 1. Handle `func->params` + Array args; + args.reserve(n_args); + for (int i = 0; i < n_args; ++i) { + tir::Var var = func->params[i]; + ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); + args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a)); + } + // Step 2. Handle `func->attrs` + if (func->attrs.defined() && !func->attrs->dict.empty()) { + (*frame)->stmts.push_back( + ExprStmtDoc(TIR(d) + ->Attr("func_attr") // + ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); + } + // Step 3. Handle `func->buffer_map` + for (int i = 0; i < n_args; ++i) { + tir::Var param = func->params[i]; + if (func->buffer_map.count(param)) { + tir::Buffer buffer = func->buffer_map[param]; + ExprDoc param = args[i]->lhs; + ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); + ExprDoc lhs = + DefineBuffer(buffer, *frame, d); // TODO(@junrushao): switch `lhs` and `rhs` + ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param}, buffer_p, *frame, d); + (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } + } + // Step 4. Handle `func->body` + AsDocBody(func->body, p->Attr("body"), frame->get(), d); + return FunctionDoc( + /*name=*/IdDoc(FindFunctionName(d, func)), + /*args=*/args, + /*decorators=*/{TIR(d)->Attr("prim_func")}, + /*return_type=*/d->AsDoc(func->ret_type, p->Attr("ret_type")), + /*body=*/(*frame)->stmts); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc new file mode 100644 index 000000000000..f4e3762fc022 --- /dev/null +++ b/src/script/printer/tir/ir.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(TIRFrameNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](IntImm imm, ObjectPath p, IRDocsifier d) -> Doc { + DataType dtype = imm->dtype; + if (dtype == Default::IntDType()) { + return LiteralDoc::Int(imm->value); + } else if (dtype == DataType::Bool()) { + return LiteralDoc::Boolean(imm->value); + } else { + return TIR(d) // + ->Attr(runtime::DLDataType2String(dtype)) + ->Call({LiteralDoc::Int(imm->value)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FloatImm imm, ObjectPath p, IRDocsifier d) -> Doc { + DataType dtype = imm->dtype; + if (dtype == Default::FloatDType()) { + return LiteralDoc::Float(imm->value); + } else { + return TIR(d) + ->Attr(runtime::DLDataType2String(dtype)) + ->Call({LiteralDoc::Float(imm->value)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { + return TIR(d)->Attr("Range")->Call({ + d->AsDoc(range->min, p->Attr("min")), + d->AsDoc(range->extent, p->Attr("extent")), + }); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { + std::string dtype = ty->dtype.is_void() ? "void" : runtime::DLDataType2String(ty->dtype); + return TIR(d)->Attr(dtype); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PointerType ty, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc element_type = d->AsDoc(ty->element_type, p->Attr("element_type")); + if (ty->storage_scope == "") { + return TIR(d)->Attr("Ptr")->Call({element_type}); + } else { + return TIR(d)->Attr("Ptr")->Call({element_type, LiteralDoc::Str(ty->storage_scope)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc { + if (ty->fields.empty()) { + return LiteralDoc::None(); + } + return TIR(d) // + ->Attr("Tuple") + ->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { + Map config = target->Export(); + return TIR(d)->Attr("target")->Call({d->AsDoc(config, p)}); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc new file mode 100644 index 000000000000..03e5657d24b7 --- /dev/null +++ b/src/script/printer/tir/stmt.cc @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../../tir/transforms/ir_utils.h" +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Doc DoConciseScoping(const Optional& lhs, const ExprDoc& rhs, Array* stmts, + bool concise_scoping) { + if (concise_scoping) { + if (lhs.defined()) { + stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, NullOpt)); + } else { + stmts->insert(stmts->begin(), ExprStmtDoc(rhs)); + } + return StmtBlockDoc(*stmts); + } else { + return ScopeDoc(lhs, rhs, *stmts); + } +} + +bool AllowConciseScoping(const IRDocsifier& d) { + ICHECK(!d->frames.empty()); + if (const auto* f = d->frames.back().as()) { + return f->allow_concise_scoping; + } + LOG(FATAL) << "NotImplementedError: fragment printing"; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc value = d->AsDoc(eval->value, p->Attr("value")); + if (eval->value->IsInstance()) { + return ExprStmtDoc(value); + } + return ExprStmtDoc(TIR(d)->Attr("evaluate")->Call({value})); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); + With f(d, stmt); + ExprDoc lhs = d->IsVarDefined(stmt->var) ? d->GetVarDoc(stmt->var).value() + : DefineVar(stmt->var, *f, d); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + Array* stmts = &(*f)->stmts; + if (concise) { + Type type = stmt->var->type_annotation; + Optional type_doc = + d->AsDoc(type, p->Attr("var")->Attr("type_annotation")); + if (const auto* tuple_type = type.as()) { + if (tuple_type->fields.empty()) { + type_doc = NullOpt; + } + } + stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); + return StmtBlockDoc(*stmts); + } else { + rhs = TIR(d)->Attr("let")->Call({lhs, rhs}); + return ScopeDoc(NullOpt, rhs, *stmts); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); + ExprDoc msg = d->AsDoc(stmt->message, p->Attr("message")); + With f(d, stmt); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + if (concise) { + Array* stmts = &(*f)->stmts; + stmts->insert(stmts->begin(), AssertDoc(cond, msg)); + return StmtBlockDoc(*stmts); + } + return ScopeDoc(NullOpt, TIR(d)->Attr("Assert")->Call({cond, msg}), (*f)->stmts); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::While stmt, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); + With f(d, stmt); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + return WhileDoc(cond, (*f)->stmts); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + ExprDoc rhs = + BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d); + With f(d, stmt); + ExprDoc lhs = DefineBuffer(stmt->buffer, *f, d); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::IfThenElse stmt, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); + Array then_branch; + Array else_branch; + if (stmt->then_case.defined()) { + With f(d, stmt->then_case); + AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d); + then_branch = (*f)->stmts; + } + if (stmt->else_case.defined()) { + With f(d, stmt->else_case); + AsDocBody(stmt->else_case.value(), p->Attr("else_case"), f->get(), d); + else_branch = (*f)->stmts; + } + return IfDoc(cond, then_branch, else_branch); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](tir::SeqStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + // TODO(@junrushao): revisit for fragment printing + With f(d, stmt); + AsDocBody(stmt, p, f->get(), d); + return StmtBlockDoc((*f)->stmts); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { + return ExprStmtDoc(TIR(d) + ->Attr("prefetch") + ->Call({ + d->AsDoc(stmt->buffer, p->Attr("buffer")), + d->AsDoc(stmt->bounds, p->Attr("bounds")), + })); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); + Array args; + Array kwargs_keys; + Array kwargs_values; + args.push_back(d->AsDoc(stmt->extents, p->Attr("extents"))); + args.push_back(LiteralDoc::DataType(stmt->dtype)); + args.push_back(LiteralDoc::Str(storage_scope)); + if (!tir::is_one(stmt->condition)) { + args.push_back(d->AsDoc(stmt->condition, p->Attr("condition"))); + } + if (!stmt->annotations.empty()) { + kwargs_keys.push_back("annotations"); + kwargs_values.push_back(d->AsDoc(stmt->annotations, p->Attr("annotations"))); + } + ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d); + With f(d, stmt); + ExprDoc rhs = TIR(d)->Attr("allocate")->Call(args, kwargs_keys, kwargs_values); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); + }); + +template +ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { + // FIXME(@junrushao): this is a hack and can be wrong in most of the cases + constexpr int NUM_PRINT = 200; + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + Array result; + T* data_ptr = reinterpret_cast(arr->data); + runtime::DataType dtype = arr.DataType(); + for (int i = 0; i < tot_dim; i++) { + if (dtype.is_float()) { + result.push_back(LiteralDoc::Float(data_ptr[i])); + } else { + result.push_back(LiteralDoc::Int(data_ptr[i])); + } + if (i == NUM_PRINT) { + break; + } + } + return ListDoc(result); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](tir::AllocateConst stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); + Array args; + Array kwargs_keys; + Array kwargs_values; + ExprDoc data_doc{nullptr}; + if (stmt->dtype.is_int()) { + if (stmt->dtype.bits() == 8) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 16) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 32) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 64) { + data_doc = PrintNDArray(stmt->data.value()); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (stmt->dtype.is_uint()) { + if (stmt->dtype.bits() == 8) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 16) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 32) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 64) { + data_doc = PrintNDArray(stmt->data.value()); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (stmt->dtype.is_float()) { + if (stmt->dtype.bits() == 16) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 32) { + data_doc = PrintNDArray(stmt->data.value()); + } else if (stmt->dtype.bits() == 64) { + data_doc = PrintNDArray(stmt->data.value()); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else { + LOG(FATAL) << "DataType not supported"; + } + args.push_back(data_doc); + args.push_back(LiteralDoc::DataType(stmt->dtype)); + args.push_back(d->AsDoc(stmt->extents, p->Attr("extents"))); + ExprDoc rhs = TIR(d)->Attr("allocate_const")->Call(args, kwargs_keys, kwargs_values); + With f(d, stmt); + ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); + }); + +ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional value, // + ObjectPath p, IRDocsifier d) { + ExprDoc buffer = d->AsDoc(stmt->buffer, p->Attr("buffer")); + { + Array bounds; + bounds.reserve(stmt->bounds.size()); + for (int i = 0, n = stmt->bounds.size(); i < n; ++i) { + Range range = stmt->bounds[i]; + ObjectPath range_p = p->Attr("bounds")->ArrayIndex(i); + bounds.push_back( + SliceDoc(d->AsDoc(range->min, range_p->Attr("min")), + d->AsDoc(range->min + range->extent, range_p->Attr("extent")), // + NullOpt)); + } + buffer = buffer[bounds]; + } + Array args{buffer}; + Array kwargs_keys; + Array kwargs_values; + if (value.defined()) { + args.push_back(value.value()); + } + if (!tir::is_one(stmt->condition)) { + kwargs_keys.push_back("condition"); + kwargs_values.push_back(d->AsDoc(stmt->condition, p->Attr("condition"))); + } + return TIR(d)->Attr("realize")->Call(args, kwargs_keys, kwargs_values); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d); + With f(d, stmt); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + return DoConciseScoping(NullOpt, rhs, &(*f)->stmts, concise); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::AttrStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + bool concise = AllowConciseScoping(d); + Optional rhs = NullOpt; + tir::Stmt body = stmt->body; + ObjectPath body_p = p->Attr("body"); + if (stmt->attr_key == "realize_scope") { + if (const auto* realize = stmt->body.as()) { + if (realize->buffer.same_as(stmt->node)) { + rhs = + DocsifyBufferRealize(realize, + /*value=*/d->AsDoc(stmt->value, p->Attr("value")), + /*p=*/p->Attr("body"), d); + body = realize->body; + body_p = body_p->Attr("body"); + } + } + } + if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { + if (const auto* iter_var = stmt->node.as()) { + if (!d->IsVarDefined(iter_var->var)) { + // `DefineVar` is not used here because a more specific name is desirable + Frame f = FindLowestVarDef(iter_var->var, d).value(); + DefineVar(iter_var->var, f, d); + f->stmts.push_back( + AssignDoc(d->AsDoc(iter_var->var, p->Attr("node")->Attr("var")), + TIR(d) // + ->Attr("env_thread") + ->Call({LiteralDoc::Str(iter_var->thread_tag)}), // + NullOpt)); + } + rhs = TIR(d) + ->Attr("launch_thread") + ->Call({ + d->AsDoc(iter_var->var, p->Attr("node")), + d->AsDoc(stmt->value, p->Attr("value")), + }); + } + } + if (!rhs.defined()) { + rhs = TIR(d)->Attr("attr")->Call({ + d->AsDoc(stmt->node, p->Attr("node")), + LiteralDoc::Str(stmt->attr_key), + d->AsDoc(stmt->value, p->Attr("value")), + }); + } + With f(d, stmt); + AsDocBody(body, body_p, f->get(), d); + return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: ProducerRealize should never exist in TIR: " << stmt; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::ProducerStore stmt, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: ProducerStore should never exist in TIR: " << stmt; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::Store stmt, ObjectPath p, IRDocsifier d) -> Doc { + LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h new file mode 100644 index 000000000000..6cae378d0e69 --- /dev/null +++ b/src/script/printer/tir/utils.h @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_ +#define TVM_SCRIPT_PRINTER_TIR_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! \brief A printer frame for TIR fragment */ +class TIRFrameNode : public FrameNode { + public: + /*! \brief The TIR fragment the frame corresponds to */ + ObjectRef tir; + /*! \brief Whether or not the frame allows concise scoping */ + bool allow_concise_scoping{false}; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("tir", &tir); + v->Visit("allow_concise_scoping", &allow_concise_scoping); + } + + static constexpr const char* _type_key = "script.printer.TIRFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode); +}; + +/*! \brief Managed reference to TIRFrameNode */ +class TIRFrame : public Frame { + public: + /*! \brief Constructor */ + explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) { + ObjectPtr n = make_object(); + n->stmts.clear(); + n->d = d.get(); + n->tir = tir; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); +}; + +/*! \brief Creates the TIR common prefix, which is by default `T` */ +inline IdDoc TIR(const IRDocsifier& d) { // + return IdDoc(d->ir_prefix.Get("tir").value_or("T")); +} + +/*! + * \brief Defines a variable in the IRDocsifier at the given frame, + * and returns the corresponding IdDoc + * \param var The variable to define + * \param d The IRDocsifier + * \param frame The frame to define the variable in + * \return The IdDoc corresponding to the variable + */ +inline IdDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { + return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint); +} + +/*! + * \brief Defines a buffer in the IRDocsifier at the given frame, + * and returns the corresponding IdDoc + * \param buffer The buffer to define + * \param frame The frame to define the buffer in + * \param d The IRDocsifier + * \return The IdDoc corresponding to the buffer + */ +inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const IRDocsifier& d) { + return d->Define(buffer, frame, buffer->name.empty() ? "buffer" : buffer->name); +} + +/*! + * \brief Recursively process the body statements of a TIR fragment represented by a frame + * \param stmt The body statement to process + * \param p The object path + * \param f The frame + * \param d The IRDocsifier + */ +inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) { + if (const auto* seq_stmt = stmt.as()) { + Array body = seq_stmt->seq; + p = p->Attr("seq"); + for (int i = 0, n = body.size(); i < n; ++i) { + f->allow_concise_scoping = (i == n - 1); + Doc doc = d->AsDoc(body[i], p->ArrayIndex(i)); + if (const auto* block = doc.as()) { + f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); + } else { + f->stmts.push_back(Downcast(doc)); + } + } + } else { + f->allow_concise_scoping = true; + Doc doc = d->AsDoc(stmt, p); + if (const auto* block = doc.as()) { + f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); + } else { + f->stmts.push_back(Downcast(doc)); + } + } +} + +/*! + * \brief Find the top frame in the stack that could place a var definition + * \param var The var to be defined + * \param d The IRDocsifier + * \return The frame that could place the var definition + */ +inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { + if (!d->common_prefix.count(var.get())) { + return NullOpt; + } + int n_frames = d->frames.size(); + std::unordered_map tir_to_frame; + tir_to_frame.reserve(n_frames); + for (int i = n_frames - 1; i >= 0; --i) { + if (const auto* f = d->frames[i].as()) { + tir_to_frame[f->tir.get()] = f; + } + } + const std::vector& path = d->common_prefix.at(var.get()); + for (auto it = path.rbegin(); it != path.rend(); ++it) { + if (tir_to_frame.count(*it)) { + return GetRef(tir_to_frame.at(*it)); + } + } + return NullOpt; +} + +/*! + * \brief Declare and define a buffer + * \param buffer The buffer to be defined + * \param method The method used to declare the buffer + * \param args The extra arguments used to declare the buffer + * \param p The object path + * \param f The frame + * \param d The IRDocsifier + * \return The ExprDoc corresponding to the buffer declaration + */ +ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, + const ObjectPath& p, const Frame& frame, const IRDocsifier& d); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_UTILS_H_ diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc deleted file mode 100644 index 43160c7f4be4..000000000000 --- a/src/script/printer/traced_object_functor.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include - -namespace tvm { -namespace script { -namespace printer { - -const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table, - const String& token, uint32_t type_index) { - auto it = table.find(token); - if (it == table.end()) { - return nullptr; - } - const std::vector& tab = it->second; - if (type_index >= tab.size()) { - return nullptr; - } - const PackedFunc* f = &tab[type_index]; - if (f->defined()) { - return f; - } else { - return nullptr; - } -} - -const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, - const String& token, uint32_t type_index) { - if (const runtime::PackedFunc* pf = - GetDispatchFunctionForToken(dispatch_table, token, type_index)) { - return *pf; - } else if (const runtime::PackedFunc* pf = - GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) { - // Fallback to function with the default dispatch token - return *pf; - } else { - ICHECK(false) << "ObjectFunctor calls un-registered function on type: " - << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"; - throw; - } -} - -void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, - runtime::PackedFunc f) { - std::vector* table = &(*dispatch_table)[token]; - if (table->size() <= type_index) { - table->resize(type_index + 1, nullptr); - } - runtime::PackedFunc& slot = (*table)[type_index]; - if (slot != nullptr) { - ICHECK(false) << "Dispatch for type is already registered: " - << runtime::Object::TypeIndex2Key(type_index); - } - slot = f; -} - -void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token, - uint32_t type_index) { - std::vector* table = &(*dispatch_table)[token]; - if (table->size() <= type_index) { - return; - } - (*table)[type_index] = nullptr; -} - -} // namespace printer -} // namespace script -} // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h deleted file mode 100644 index abe7ce5e9a88..000000000000 --- a/src/script/printer/utils.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ -#define TVM_SCRIPT_PRINTER_UTILS_H_ - -#include -#include - -#include - -namespace tvm { -namespace script { -namespace printer { - -template -Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { - Array result; - for (auto ref : refs) { - result.push_back(ir_docsifier->AsExprDoc(ref)); - } - return result; -} - -template -Array AsDocArray(std::initializer_list&& refs, const IRDocsifier& ir_docsifier) { - Array result; - for (auto& ref : refs) { - result.push_back(ir_docsifier->AsExprDoc(ref)); - } - return result; -} - -template -Array AsExprDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { - return AsDocArray(refs, ir_docsifier); -} - -template -Array AsExprDocArray(std::initializer_list&& refs, - const IRDocsifier& ir_docsifier) { - return AsDocArray(std::move(refs), ir_docsifier); -} - -inline DictDoc AsDictDoc(const TracedMap& dict, - const IRDocsifier& ir_docsifier) { - Array keys; - Array values; - - for (auto p : dict) { - keys.push_back(LiteralDoc::Str(p.first)); - values.push_back(ir_docsifier->AsExprDoc(p.second)); - } - - auto doc = DictDoc(keys, values); - doc->source_paths.push_back(dict.GetPath()); - return doc; -} - -template -inline ListDoc AsListDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { - auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier)); - ret->source_paths.push_back(arr.GetPath()); - return ret; -} - -template -inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { - auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier)); - ret->source_paths.push_back(arr.GetPath()); - return ret; -} - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc deleted file mode 100644 index 62d8b2f66cc2..000000000000 --- a/src/script/printer/var_table.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -String GenerateUniqueName(const String& name_hint, std::unordered_set* defined_names) { - String name = name_hint; - for (int i = 1; !defined_names->insert(name).second; ++i) { - name = name_hint + "_" + std::to_string(i); - } - return name; -} - -IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint, - const ObjectPath& object_path, const Frame& frame) { - String name = GenerateUniqueName(name_hint, &this->defined_names); - DocFactory doc_factory = [name]() { return IdDoc(name); }; - - auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); - ICHECK(result.second) << "Duplicated object: " << obj; - - IdDoc def_doc(name); - def_doc->source_paths.push_back(object_path); - - frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); - - return def_doc; -} - -void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) { - ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - - ICHECK(!doc_factory()->IsInstance()) - << "VarTableNode::Define cannot be used for variable that's mapped to IdDoc."; - - obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); - - frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); -} - -Optional VarTableNode::GetVarDoc(const ObjectRef& obj, - const ObjectPath& object_path) const { - auto it = obj2info.find(obj); - if (it == obj2info.end()) { - return NullOpt; - } - ExprDoc doc = it->second.doc_factory(); - doc->source_paths.push_back(object_path); - return doc; -} - -bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } - -void VarTableNode::RemoveVar(const ObjectRef& obj) { - auto it = obj2info.find(obj); - ICHECK(it != obj2info.end()) << "No such object: " << obj; - - if (it->second.name.defined()) { - defined_names.erase(it->second.name.value()); - } - obj2info.erase(it); -} - -VarTable::VarTable() { data_ = make_object(); } - -TVM_REGISTER_NODE_TYPE(VarTableNode); -TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); }); -TVM_REGISTER_GLOBAL("script.printer.VarTableDefine") - .set_body_method(&VarTableNode::Define); -TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") - .set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory, - Frame frame) { - var_table->DefineByDoc( - obj, [f = std::move(factory)]() { return f(); }, frame); - }); -TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") - .set_body_method, const ObjectRef&, - const ObjectPath&>(&VarTableNode::GetVarDoc); -TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") - .set_body_method(&VarTableNode::IsVarDefined); - -} // namespace printer -} // namespace script -} // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index a8d8936c905a..af6997a72aa3 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -1088,7 +1088,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) { return tir::Call(dtype, op, {}, span); } -TVM_REGISTER_OP("tir.type_annotation") +TVM_TIR_REGISTER_OP("type_annotation") .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace tir diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 56ecba9e9ed9..dc3208f484e3 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -36,7 +36,7 @@ namespace builtin { static const Op& op = Op::Get("tir." #OpName); \ return op; \ } \ - TVM_REGISTER_OP("tir." #OpName) + TVM_TIR_REGISTER_OP(#OpName) TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) @@ -181,10 +181,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) // When num_inputs are not set, the function is assumed to be variable length. TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptPrinterName", String("call_packed"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptPrinterName", String("call_cpacked"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -198,10 +200,14 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptPrinterName", String("call_packed_lowered"), + /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptPrinterName", String("call_cpacked_lowered"), + /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 044d8fd08da5..078e32ca57c7 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -39,13 +39,13 @@ namespace tvm { using namespace tir; // macro to register an unary op -#define TIR_REGISTER_PURE_UNARY_OP(OpName) \ - TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr( \ +#define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) \ + TVM_TIR_REGISTER_OP(OpName).set_num_inputs(1).set_attr( \ "TCallEffectKind", Integer(CallEffectKind::kPure)) // macro to register an binary op -#define TIR_REGISTER_PURE_BINARY_OP(OpName) \ - TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr( \ +#define TVM_TIR_REGISTER_PURE_BINARY_OP(OpName) \ + TVM_TIR_REGISTER_OP(OpName).set_num_inputs(2).set_attr( \ "TCallEffectKind", Integer(CallEffectKind::kPure)) runtime::DataType GetRuntimeDataType(const Type& type) { @@ -657,7 +657,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { return tir::Call(x.dtype(), op, {x, y}, span); } -TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", true); // abs PrimExpr abs(PrimExpr x, Span span) { @@ -685,7 +685,7 @@ PrimExpr abs(PrimExpr x, Span span) { } } -TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("fabs").set_attr("TVectorizable", true); // isnan PrimExpr isnan(PrimExpr x, Span span) { @@ -783,7 +783,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { return tir::Call(x.dtype(), op, {x, y}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.fmod"); +TVM_TIR_REGISTER_PURE_UNARY_OP("fmod"); // floor PrimExpr floor(PrimExpr x, Span span) { @@ -797,7 +797,7 @@ PrimExpr floor(PrimExpr x, Span span) { return tir::Call(x.dtype(), op, {x}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr("TVectorizable", true); // ceil PrimExpr ceil(PrimExpr x, Span span) { @@ -811,7 +811,7 @@ PrimExpr ceil(PrimExpr x, Span span) { return tir::Call(x.dtype(), op, {x}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr("TVectorizable", true); // round PrimExpr round(PrimExpr x, Span span) { @@ -825,7 +825,7 @@ PrimExpr round(PrimExpr x, Span span) { return tir::Call(x.dtype(), op, {x}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr("TVectorizable", true); // nearbyint PrimExpr nearbyint(PrimExpr x, Span span) { @@ -839,7 +839,7 @@ PrimExpr nearbyint(PrimExpr x, Span span) { return tir::Call(x.dtype(), op, {x}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); +TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint"); // trunc PrimExpr trunc(PrimExpr x, Span span) { @@ -856,67 +856,77 @@ PrimExpr trunc(PrimExpr x, Span span) { return tir::Call(x.dtype(), op, {x}, span); } -TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr("TVectorizable", true); // unary op registration. -TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("exp").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("exp2").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("exp10").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.erf"); +TVM_TIR_REGISTER_PURE_UNARY_OP("erf"); -TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("tanh").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("sigmoid").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("sqrt").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt"); +TVM_TIR_REGISTER_PURE_UNARY_OP("rsqrt"); -TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("log").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("log2").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.log1p"); +TVM_TIR_REGISTER_PURE_UNARY_OP("log1p"); -TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("log10").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("tan").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("cos").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("cosh").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("sin").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr("TVectorizable", true); +TVM_TIR_REGISTER_PURE_UNARY_OP("sinh").set_attr("TVectorizable", true); -TIR_REGISTER_PURE_UNARY_OP("tir.asin"); +TVM_TIR_REGISTER_PURE_UNARY_OP("asin"); -TIR_REGISTER_PURE_UNARY_OP("tir.acos"); +TVM_TIR_REGISTER_PURE_UNARY_OP("acos"); -TIR_REGISTER_PURE_UNARY_OP("tir.atan"); +TVM_TIR_REGISTER_PURE_UNARY_OP("atan"); -TIR_REGISTER_PURE_UNARY_OP("tir.acosh"); +TVM_TIR_REGISTER_PURE_UNARY_OP("acosh"); -TIR_REGISTER_PURE_UNARY_OP("tir.asinh"); +TVM_TIR_REGISTER_PURE_UNARY_OP("asinh"); -TIR_REGISTER_PURE_UNARY_OP("tir.atanh"); +TVM_TIR_REGISTER_PURE_UNARY_OP("atanh"); -TIR_REGISTER_PURE_UNARY_OP("tir.clz"); +TVM_TIR_REGISTER_PURE_UNARY_OP("clz"); // binary intrinsics -TIR_REGISTER_PURE_BINARY_OP("tir.atan2"); +TVM_TIR_REGISTER_PURE_BINARY_OP("atan2"); -TIR_REGISTER_PURE_BINARY_OP("tir.nextafter"); +TVM_TIR_REGISTER_PURE_BINARY_OP("nextafter"); -TIR_REGISTER_PURE_BINARY_OP("tir.hypot"); +TVM_TIR_REGISTER_PURE_BINARY_OP("hypot"); -TIR_REGISTER_PURE_BINARY_OP("tir.copysign"); +TVM_TIR_REGISTER_PURE_BINARY_OP("copysign"); -TIR_REGISTER_PURE_BINARY_OP("tir.ldexp"); +TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp"); + +TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc deleted file mode 100644 index adabae9e75f7..000000000000 --- a/src/tir/op/runtime.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/op/runtime.cc - * \brief TIR ops for runtime functions. - */ -#include -#include - -namespace tvm { -namespace tir { - -TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace") - .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace") - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace") - .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace") - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -} // namespace tir -} // namespace tvm diff --git a/tests/cpp/traced_object_test.cc b/tests/cpp/traced_object_test.cc deleted file mode 100644 index 7890a67eef95..000000000000 --- a/tests/cpp/traced_object_test.cc +++ /dev/null @@ -1,268 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -using namespace tvm; - -namespace { - -class DummyObjectNode : public Object { - public: - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const char* _type_key = "TracedObjectTestDummyObject"; - TVM_DECLARE_FINAL_OBJECT_INFO(DummyObjectNode, Object); -}; - -class DummyObject : public ObjectRef { - public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DummyObject, ObjectRef, DummyObjectNode); -}; - -TVM_REGISTER_NODE_TYPE(DummyObjectNode); - -class ObjectWithAttrsNode : public Object { - public: - int64_t int64_attr = 5; - Map map_attr; - Array array_attr; - DummyObject obj_attr; - - ObjectWithAttrsNode() : obj_attr(make_object()) {} - - void VisitAttrs(AttrVisitor* v) { - v->Visit("int64_attr", &int64_attr); - v->Visit("map_attr", &map_attr); - v->Visit("array_attr", &array_attr); - v->Visit("obj_attr", &obj_attr); - } - - static constexpr const char* _type_key = "TracedObjectTestObjectWithAttrs"; - TVM_DECLARE_FINAL_OBJECT_INFO(ObjectWithAttrsNode, Object); -}; - -class ObjectWithAttrs : public ObjectRef { - public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectWithAttrs, ObjectRef, ObjectWithAttrsNode); -}; - -TVM_REGISTER_NODE_TYPE(ObjectWithAttrsNode); - -} // anonymous namespace - -TEST(TracedObjectTest, MakeTraced_RootObject) { - ObjectWithAttrs root(make_object()); - auto root_traced = MakeTraced(root); - - static_assert(std::is_same>::value); - ICHECK(root_traced.GetPath()->PathsEqual(ObjectPath::Root())); - ICHECK_EQ(root_traced.Get().get(), root.get()); -} - -TEST(TracedObjectTest, MakeTraced_WithPath) { - ObjectWithAttrs obj(make_object()); - auto traced = MakeTraced(obj, ObjectPath::Root()->Attr("foo")); - - static_assert(std::is_same>::value); - ICHECK(traced.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); - ICHECK_EQ(traced.Get().get(), obj.get()); -} - -TEST(TracedObjectTest, TracedObject_ImplicitConversionFromDerived) { - DummyObject obj(make_object()); - auto traced = MakeTraced(obj); - static_assert(std::is_same>::value); - - // Check that TracedObject is implicitly converted to TracedObject - auto base_traced = [](const TracedObject& base) { return base; }(traced); - - static_assert(std::is_same>::value); -} - -TEST(TracedObjectTest, TracedObject_GetAttr_ObjectRef) { - ObjectWithAttrs root(make_object()); - auto root_traced = MakeTraced(root); - auto obj_attr = root_traced.GetAttr(&ObjectWithAttrsNode::obj_attr); - static_assert(std::is_same>::value); - ICHECK(obj_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("obj_attr"))); - ICHECK_EQ(obj_attr.Get().get(), root->obj_attr.get()); -} - -TEST(TracedObjectTest, TracedObject_GetAttr_Map) { - ObjectWithAttrs root(make_object()); - root->map_attr.Set("foo", "bar"); - - auto root_traced = MakeTraced(root); - auto map_attr = root_traced.GetAttr(&ObjectWithAttrsNode::map_attr); - static_assert(std::is_same>::value); - ICHECK(map_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr"))); - ICHECK_EQ(map_attr.Get().get(), root->map_attr.get()); - - auto map_val = map_attr.at("foo"); - ICHECK_EQ(map_val.Get(), "bar"); - ICHECK( - map_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr")->MapValue(String("foo")))); -} - -TEST(TracedObjectTest, TracedObject_GetAttr_Array) { - ObjectWithAttrs root(make_object()); - root->array_attr.push_back("foo"); - root->array_attr.push_back("bar"); - - auto root_traced = MakeTraced(root); - auto array_attr = root_traced.GetAttr(&ObjectWithAttrsNode::array_attr); - static_assert(std::is_same>::value); - ICHECK(array_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr"))); - ICHECK_EQ(array_attr.Get().get(), root->array_attr.get()); - - auto array_val = array_attr[1]; - ICHECK_EQ(array_val.Get(), "bar"); - ICHECK(array_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr")->ArrayIndex(1))); -} - -TEST(TracedObjectTest, TracedObject_GetAttr_Int64) { - ObjectWithAttrs root(make_object()); - auto root_traced = MakeTraced(root); - - auto int64_attr = root_traced.GetAttr(&ObjectWithAttrsNode::int64_attr); - static_assert(std::is_same>::value); - ICHECK_EQ(int64_attr.Get(), 5); - ICHECK(int64_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("int64_attr"))); -} - -TEST(TracedObjectTest, TracedObject_IsInstance) { - ObjectRef dummy(make_object()); - auto traced = MakeTraced(dummy); - ICHECK(traced.IsInstance()); - ICHECK(!traced.IsInstance()); -} - -TEST(TracedObjectTest, TracedObject_Downcast) { - ObjectRef root(make_object()); - auto traced = MakeTraced(root); - - auto as_dummy = traced.Downcast(); - static_assert(std::is_same>::value); - ICHECK_EQ(as_dummy.Get(), root); - - // Try downcasting to a wrong type - bool caught = false; - try { - traced.Downcast(); - } catch (std::exception& e) { - caught = strstr(e.what(), - "Downcast from TracedObjectTestDummyObject to TracedObjectTestObjectWithAttrs " - "failed") != nullptr; - } - ICHECK(caught); -} - -TEST(TracedObjectTest, TracedObject_TryDowncast) { - ObjectRef root(make_object()); - auto traced = MakeTraced(root); - - auto as_dummy = traced.TryDowncast(); - static_assert(std::is_same>::value); - ICHECK(as_dummy.defined()); - ICHECK_EQ(as_dummy.value().Get(), root); - - // Try downcasting to a wrong type - ICHECK(!traced.TryDowncast().defined()); -} - -TEST(TracedObjectTest, TracedMap_At) { - Map m({{"k1", "foo"}, {"k2", "bar"}}); - auto traced = MakeTraced(m); - - auto traced_foo = traced.at("k1"); - static_assert(std::is_same>::value); - ICHECK_EQ(traced_foo.Get(), "foo"); - ICHECK(traced_foo.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); -} - -TEST(TracedObjectTest, TracedMap_Iterator) { - Map m({{"k1", "foo"}, {"k2", "bar"}}); - auto traced = MakeTraced(m); - - size_t k1_count = 0; - size_t k2_count = 0; - - for (const auto& kv : traced) { - if (kv.first == "k1") { - ++k1_count; - ICHECK_EQ(kv.second.Get(), "foo"); - ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); - } else if (kv.first == "k2") { - ++k2_count; - ICHECK_EQ(kv.second.Get(), "bar"); - ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k2")))); - } else { - ICHECK(false); - } - } - - ICHECK_EQ(k1_count, 1); - ICHECK_EQ(k2_count, 1); -} - -TEST(TracedObjectTest, TracedArray_Index) { - Array a = {"foo", "bar"}; - auto traced = MakeTraced(a); - - auto traced_bar = traced[1]; - static_assert(std::is_same>::value); - ICHECK_EQ(traced_bar.Get(), "bar"); - ICHECK(traced_bar.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); -} - -TEST(TracedObjectTest, TracedArray_Iterator) { - Array a = {"foo", "bar"}; - auto traced = MakeTraced(a); - - size_t index = 0; - for (const auto& x : traced) { - if (index == 0) { - ICHECK_EQ(x.Get(), "foo"); - ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(0))); - } else if (index == 1) { - ICHECK_EQ(x.Get(), "bar"); - ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); - } else { - ICHECK(false); - } - ++index; - } - - ICHECK_EQ(index, 2); -} - -TEST(TracedObjectTest, TracedBasicValue_ApplyFunc) { - auto traced = MakeTraced(123, ObjectPath::Root()->Attr("foo")); - static_assert(std::is_same>::value); - - auto transformed = traced.ApplyFunc([](int x) { return x + 4.0; }); - static_assert(std::is_same>::value); - - ICHECK(transformed.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); -} diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc deleted file mode 100644 index 8c68399df222..000000000000 --- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -using namespace tvm; -using namespace tvm::script::printer; - -class TestObjectNode : public Object { - public: - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const char* _type_key = "test.script.printer.irdocsifier.TestObject"; - TVM_DECLARE_FINAL_OBJECT_INFO(TestObjectNode, Object); -}; - -class TestObject : public ObjectRef { - public: - TestObject() : ObjectRef(runtime::make_object()) {} - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TestObject, ObjectRef, TestObjectNode); -}; - -TVM_REGISTER_NODE_TYPE(TestObjectNode); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) { - return IdDoc("x"); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { - return IdDoc("tir"); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { - return IdDoc("relax"); - }); - -TEST(PrinterIRDocsifierTest, AsDoc) { - IRDocsifier p(Map{}); - ObjectPath path = ObjectPath::Root(); - TestObject obj; - - IdDoc doc = p->AsDoc(MakeTraced(obj, path)); - - ICHECK_EQ(doc->name, "x"); -} - -TEST(PrinterIRDocsifierTest, AsExprDoc) { - IRDocsifier p(Map{}); - ObjectPath path = ObjectPath::Root(); - TestObject obj; - - ExprDoc doc = p->AsExprDoc(MakeTraced(obj, path)); - - ICHECK_EQ(Downcast(doc)->name, "x"); -} - -TEST(PrinterIRDocsifierTest, WithDispatchToken) { - IRDocsifier p(Map{}); - TracedObject obj = MakeTraced(TestObject(), ObjectPath::Root()); - - ICHECK_EQ(p->AsDoc(obj)->name, "x"); - - { - auto ctx = p->WithDispatchToken("tir"); - ICHECK_EQ(p->AsDoc(obj)->name, "tir"); - - { - auto ctx = p->WithDispatchToken("relax"); - ICHECK_EQ(p->AsDoc(obj)->name, "relax"); - } - - ICHECK_EQ(p->AsDoc(obj)->name, "tir"); - } - - ICHECK_EQ(p->AsDoc(obj)->name, "x"); -} - -TEST(PrinterIRDocsifierTest, WithFrame) { - IRDocsifier p(Map{}); - TestObject obj; - - { - VarDefFrame frame; - auto ctx = p->WithFrame(frame); - ICHECK_EQ(p->frames.size(), 1); - - p->vars->Define(obj, "x", ObjectPath::Root(), frame); - ICHECK(p->vars->IsVarDefined(obj)); - } - ICHECK_EQ(p->frames.size(), 0); - ICHECK(!p->vars->IsVarDefined(obj)); -} diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc deleted file mode 100644 index d662ce132405..000000000000 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace tvm; -using namespace tvm::script::printer; - -namespace { - -class FooObjectNode : public Object { - public: - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject"; - TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); -}; - -class FooObject : public ObjectRef { - public: - FooObject() { this->data_ = make_object(); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode); -}; - -TVM_REGISTER_NODE_TYPE(FooObjectNode); - -class BarObjectNode : public Object { - public: - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject"; - TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); -}; - -class BarObject : public ObjectRef { - public: - BarObject() { this->data_ = make_object(); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode); -}; - -TVM_REGISTER_NODE_TYPE(BarObjectNode); - -String ComputeFoo(TracedObject foo) { return "Foo"; } - -} // anonymous namespace - -TEST(TracedObjectFunctorTest, NormalRegistration) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); - ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); -} - -TEST(TracedObjectFunctorTest, RegistrationWithFunction) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); - functor.set_dispatch("tir", ComputeFoo); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); - ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); -} - -TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", - [](TracedObject o) -> String { return "Foo tir"; }); - functor.set_dispatch("relax", - [](TracedObject o) -> String { return "Foo relax"; }); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); - ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); - ICHECK_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax"); - ICHECK_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo"); -} - -TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); }; - auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); }; - - functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default)); - functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir)); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "default"); - ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir"); -} - -TEST(TracedObjectFunctorTest, ExtraArg) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o, int x) { return x; }); - functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); - ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); - ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3); -} - -TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", - [](TracedObject o) -> String { return "Foo tir"; }); - - ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); - ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); - - functor.remove_dispatch("tir", FooObjectNode::RuntimeTypeIndex()); - ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); -} - -TEST(TracedObjectFunctorTest, CallWithUnregisteredType) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - bool failed = false; - try { - ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); - } catch (...) { - failed = true; - } - ASSERT_EQ(failed, true); -} - -TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch([](TracedObject o, int x) { return x; }); - - bool failed = false; - try { - functor.set_dispatch([](TracedObject o, int x) { return x; }); - } catch (...) { - failed = true; - } - ASSERT_EQ(failed, true); -} - -TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { - TracedObjectFunctor functor; - ObjectPath path = ObjectPath::Root(); - - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); - - bool failed = false; - try { - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); - } catch (...) { - failed = true; - } - ASSERT_EQ(failed, true); -} diff --git a/tests/cpp/tvmscript_printer_var_table_test.cc b/tests/cpp/tvmscript_printer_var_table_test.cc deleted file mode 100644 index b447c81ac0b8..000000000000 --- a/tests/cpp/tvmscript_printer_var_table_test.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -using namespace tvm; -using namespace tvm::script::printer; - -TEST(PrinterVarTableTest, Define) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - - IdDoc doc = vars->Define(x, "x", object_path, frame); - - ICHECK_EQ(doc->name, "x"); - - IdDoc second_doc = Downcast(vars->GetVarDoc(x, object_path).value()); - - ICHECK_EQ(second_doc->name, "x"); -} - -TEST(PrinterVarTableTest, DefineByDoc) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - - auto doc_factory = []() { return LiteralDoc::Str("x"); }; - - vars->DefineByDoc(x, doc_factory, frame); - - ExprDoc doc = vars->GetVarDoc(x, object_path).value(); - - ICHECK_EQ(Downcast(Downcast(doc)->value), "x"); -} - -TEST(PrinterVarTableTest, GetVarDocWithUnknownVariable) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - tir::Var y("y"); - ObjectPath object_path = ObjectPath::Root(); - - Doc doc = vars->Define(x, "x", object_path, frame); - ICHECK(!vars->GetVarDoc(y, object_path).defined()); -} - -TEST(PrinterVarTableTest, GetVarDocWithObjectPath) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - ObjectPath second_object_path = ObjectPath::Root()->Attr("x"); - - IdDoc doc = vars->Define(x, "x", object_path, frame); - ICHECK_EQ(doc->source_paths[0], object_path); - ICHECK_EQ(doc->source_paths.size(), 1); - - Doc second_doc = vars->GetVarDoc(x, second_object_path).value(); - ICHECK_EQ(second_doc->source_paths[0], second_object_path); - ICHECK_EQ(second_doc->source_paths.size(), 1); -} - -TEST(PrinterVarTableTest, IsVarDefined) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - tir::Var y("y"); - ObjectPath object_path = ObjectPath::Root(); - - vars->Define(x, "x", object_path, frame); - ICHECK(vars->IsVarDefined(x)); - ICHECK(!vars->IsVarDefined(y)); -} - -TEST(PrinterVarTableTest, VarRemovedAfterFrameOutOfScope) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - - vars->Define(x, "x", object_path, frame); - ICHECK(vars->IsVarDefined(x)); - - frame->ExitWithScope(); - ICHECK(!vars->IsVarDefined(x)); -} - -TEST(PrinterVarTableTest, DefineDuplicateName) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - tir::Var y("y"); - ObjectPath object_path = ObjectPath::Root(); - - IdDoc x_doc = vars->Define(x, "x", object_path, frame); - IdDoc y_doc = vars->Define(y, "x", object_path, frame); - - ICHECK_NE(x_doc->name, y_doc->name); -} - -TEST(PrinterVarTableTest, DefineDuplicateVariable) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - - vars->Define(x, "x", object_path, frame); - - bool failed = false; - try { - vars->Define(x, "x", object_path, frame); - } catch (...) { - failed = true; - } - ASSERT_EQ(failed, true); -} - -TEST(PrinterVarTableTest, DefineByDocWithIdDoc) { - VarTable vars; - MetadataFrame frame; - tir::Var x("x"); - ObjectPath object_path = ObjectPath::Root(); - - bool failed = false; - try { - // User has to use `Define` if variable needs to be mapped to IdDoc - vars->DefineByDoc( - x, []() { return IdDoc("x"); }, frame); - } catch (...) { - failed = true; - } - ASSERT_EQ(failed, true); -} diff --git a/tests/python/unittest/test_tvmscript_printer_entry_point.py b/tests/python/unittest/test_tvmscript_printer_entry_point.py deleted file mode 100644 index 208386dbdd4a..000000000000 --- a/tests/python/unittest/test_tvmscript_printer_entry_point.py +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest - -from tvm.error import TVMError -from tvm.script.printer import script -from tvm.tir import FloatImm - - -def test_as_script_unknown_ir(): - ir_node = FloatImm("float32", 1.0) - - with pytest.raises(TVMError) as e: - script(ir_node, "test_xyz", {}) - - assert "test_xyz" in str(e.value) diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py deleted file mode 100644 index bd98d6445644..000000000000 --- a/tests/python/unittest/test_tvmscript_printer_frame.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from tvm.script.printer.frame import MetadataFrame - - -def test_frame_add_callback(): - frame = MetadataFrame() - - flag = 0 - - def callback1(): - nonlocal flag - flag += 1 - - def callback2(): - nonlocal flag - flag += 5 - - frame.add_exit_callback(callback1) - with frame: - frame.add_exit_callback(callback2) - assert flag == 0 - - assert flag == 6 - - -def test_frame_clear_callbacks_after_exit(): - frame = MetadataFrame() - - flag = 0 - - def callback(): - nonlocal flag - flag += 1 - - frame.add_exit_callback(callback) - - with frame: - pass - - assert flag == 1 - - with frame: - pass - - assert flag == 1 diff --git a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py deleted file mode 100644 index d9d552ce4b9f..000000000000 --- a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py +++ /dev/null @@ -1,123 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest - -from tvm.runtime import ObjectPath -from tvm.script.printer.doc import IdDoc -from tvm.script.printer.frame import MetadataFrame, VarDefFrame -from tvm.script.printer.ir_docsifier import IRDocsifier, RootNodeContainer -from tvm.tir import Var - - -@pytest.fixture -def ir_docsifier(): - """ - Creates an IRDocsifier instance with a special dispatch token. - """ - _ir_docsifier = IRDocsifier({}) - with _ir_docsifier.dispatch_token(f"{__file__}"): - yield _ir_docsifier - - -def _get_id_doc_printer(id_name): - def printer(obj, object_path, ir_docsifier): # pylint: disable=unused-argument - return IdDoc(id_name) - - return printer - - -def _root_dispatch_function(obj, ir_docsifier): - doc = ir_docsifier.as_doc(obj, ObjectPath.root()) - doc.source_paths = [ObjectPath.root().attr("irdocsifier_test")] - return doc - - -# Because the dispatch table is global, tests should only set dispatch function under -# unique dispatch token. -IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x"), f"{__file__}") -IRDocsifier.set_root_dispatch(f"{__file__}", _root_dispatch_function) - - -def test_set_dispatch(ir_docsifier): - IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x2"), f"{__file__}-2") - with ir_docsifier.dispatch_token(f"{__file__}-2"): - doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root()) - assert doc.name == "x2" - - doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root()) - assert doc.name == "x" - - -def test_set_root_dispatch(ir_docsifier): - doc = ir_docsifier.as_doc(RootNodeContainer(Var("x", dtype="int8")), ObjectPath.root()) - assert ObjectPath.root().attr("irdocsifier_test") in doc.source_paths - - -def test_as_doc(ir_docsifier): - object_path = ObjectPath.root() - doc = ir_docsifier.as_doc(Var("x", "int8"), ObjectPath.root()) - assert doc.name == "x" - assert list(doc.source_paths) == [object_path] - - -def test_with_dispatch_token(ir_docsifier): - initial_token_count = len(ir_docsifier.dispatch_tokens) - - with ir_docsifier.dispatch_token("tir"): - assert len(ir_docsifier.dispatch_tokens) == initial_token_count + 1 - - assert len(ir_docsifier.dispatch_tokens) == initial_token_count - - -def test_with_frame(ir_docsifier): - initial_frame_count = len(ir_docsifier.frames) - - frame = VarDefFrame() - is_callback_called = False - - def callback(): - nonlocal is_callback_called - is_callback_called = True - - frame.add_exit_callback(callback) - - with ir_docsifier.frame(frame): - assert len(ir_docsifier.frames) == initial_frame_count + 1 - assert not is_callback_called - - assert len(ir_docsifier.frames) == initial_frame_count - assert is_callback_called - - -def test_get_frame(ir_docsifier): - with ir_docsifier.frame(VarDefFrame()) as frame_a: - assert ir_docsifier.get_frame(MetadataFrame) is None - assert ir_docsifier.get_frame(VarDefFrame) == frame_a - - with ir_docsifier.frame(VarDefFrame()) as frame_b: - assert ir_docsifier.get_frame(MetadataFrame) is None - assert ir_docsifier.get_frame(VarDefFrame) == frame_b - - with ir_docsifier.frame(MetadataFrame()) as frame_c: - assert ir_docsifier.get_frame(MetadataFrame) == frame_c - assert ir_docsifier.get_frame(VarDefFrame) == frame_b - - assert ir_docsifier.get_frame(MetadataFrame) is None - assert ir_docsifier.get_frame(VarDefFrame) == frame_b - - assert ir_docsifier.get_frame(MetadataFrame) is None - assert ir_docsifier.get_frame(VarDefFrame) == frame_a diff --git a/tests/python/unittest/test_tvmscript_printer_var_table.py b/tests/python/unittest/test_tvmscript_printer_var_table.py deleted file mode 100644 index eab63a08ddad..000000000000 --- a/tests/python/unittest/test_tvmscript_printer_var_table.py +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This file tests the FFI binding of script.printer.VarTable. -These only make sure parameter can be passed to the C++ functions -correctly. The test for the functionality of VarTable is in C++. -""" - -from tvm.runtime import ObjectPath -from tvm.script.printer.doc import LiteralDoc -from tvm.script.printer.frame import VarDefFrame -from tvm.script.printer.var_table import VarTable -from tvm.tir import Var - - -def test_define(): - var_table = VarTable() - var_name = "a" - var_obj = Var(var_name, dtype="int32") - object_path = ObjectPath.root().attr("a") - frame = VarDefFrame() - - id_doc = var_table.define(var_obj, var_name, object_path, frame) - - assert id_doc.name == "a" - assert list(id_doc.source_paths) == [object_path] - - id_doc = var_table.get_var_doc(var_obj, object_path) - - assert id_doc.name == "a" - assert list(id_doc.source_paths) == [object_path] - - -def test_define_by_doc(): - var_table = VarTable() - var_name = "a" - var_obj = Var(var_name, dtype="int32") - object_path = ObjectPath.root().attr("a") - frame = VarDefFrame() - - var_table.define_by_doc(var_obj, lambda: LiteralDoc(var_name), frame) - - var_doc = var_table.get_var_doc(var_obj, object_path) - - assert isinstance(var_doc, LiteralDoc) - assert var_doc.value == var_name - assert list(var_doc.source_paths) == [object_path] - - -def test_is_var_defined(): - var_table = VarTable() - a = Var("a", dtype="int32") - object_path = ObjectPath.root().attr("a") - frame = VarDefFrame() - - var_table.define(a, "a", object_path, frame) - - assert var_table.is_var_defined(a) - assert a in var_table - - -def test_var_out_of_scope(): - var_table = VarTable() - var_name = "a" - var_obj = Var(var_name, dtype="int32") - object_path = ObjectPath.root().attr("a") - frame = VarDefFrame() - - var_table.define(var_obj, var_name, object_path, frame) - - with frame: - assert var_obj in var_table - - assert var_obj not in var_table - assert var_table.get_var_doc(var_obj, object_path) is None