diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 2a2d0bf3fb05..3293b43564cc 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -63,6 +63,8 @@ class ReprLegacyPrinter { TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); + /*! \brief Could the LegacyPrinter dispatch the node */ + TVM_DLL static bool CanDispatch(const ObjectRef& node); /*! \brief Return the ostream it maintains */ TVM_DLL std::ostream& Stream() const; // Allow registration to be printer. diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index c51d6a52b910..e36e37bdf1a0 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -35,6 +35,10 @@ namespace tvm { class PrinterConfigNode : public Object { public: + /*! \brief A stack that tracks the names of the binding hierarchy */ + Array binding_names = {}; + /*! \brief Whether or not to show metadata. */ + bool show_meta = false; /*! \brief The prefix of IR nodes */ std::string ir_prefix = "I"; /*! \brief The prefix of TIR nodes */ @@ -71,6 +75,8 @@ class PrinterConfigNode : public Object { Map obj_to_annotate = Map(); void VisitAttrs(AttrVisitor* v) { + v->Visit("binding_names", &binding_names); + v->Visit("show_meta", &show_meta); v->Visit("ir_prefix", &ir_prefix); v->Visit("buffer_dtype", &buffer_dtype); v->Visit("int_dtype", &int_dtype); diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 9225e7de3369..156daebf001f 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -141,10 +141,10 @@ class IRDocsifierNode : public Object { * when converting IR node object to Doc. */ Array dispatch_tokens; - /*! \brief The IRModule to be docsifier is handling */ - Optional mod; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; + /*! \brief Metadata printing */ + std::unordered_map> metadata; /*! \brief The variable names used already */ std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ @@ -155,8 +155,8 @@ class IRDocsifierNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("frames", &frames); v->Visit("dispatch_tokens", &dispatch_tokens); - v->Visit("mod", &mod); // `obj2info` is not visited + // `metadata` is not visited // `defined_names` is not visited // `common_prefix` is not visited // `ir_usage` is not visited @@ -204,7 +204,8 @@ class IRDocsifierNode : public Object { * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. */ Optional GetVarDoc(const ObjectRef& obj) const; - + /*! \brief Add a TVM object to the metadata section*/ + ExprDoc AddMetadata(const ObjectRef& obj); /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 6838865490ad..a2f4cdc33135 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" -from typing import List, Dict, Optional +from typing import Dict, List, Optional, Sequence -from tvm._ffi import register_object +from tvm._ffi import get_global_func, register_object from tvm.runtime import Object from . import _ffi_node_api @@ -28,6 +28,8 @@ class PrinterConfig(Object): """Configuration of TVMScript printer""" + binding_names: Sequence[str] + show_meta: bool ir_prefix: str tir_prefix: str relax_prefix: str @@ -47,6 +49,8 @@ class PrinterConfig(Object): def __init__( self, *, + name: Optional[str] = None, + show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", relax_prefix: str = "R", @@ -65,25 +69,29 @@ def __init__( ) -> None: if num_context_lines is None: num_context_lines = -1 + cfg = { + "show_meta": show_meta, + "ir_prefix": ir_prefix, + "tir_prefix": tir_prefix, + "relax_prefix": relax_prefix, + "buffer_dtype": buffer_dtype, + "int_dtype": int_dtype, + "float_dtype": float_dtype, + "verbose_expr": verbose_expr, + "indent_spaces": indent_spaces, + "print_line_numbers": print_line_numbers, + "num_context_lines": num_context_lines, + "syntax_sugar": syntax_sugar, + "path_to_underline": path_to_underline, + "path_to_annotate": path_to_annotate, + "obj_to_underline": obj_to_underline, + "obj_to_annotate": obj_to_annotate, + } + + if name is not None: + cfg["name"] = name self.__init_handle_by_constructor__( - _ffi_node_api.PrinterConfig, # type: ignore # pylint: disable=no-member - { - "ir_prefix": ir_prefix, - "tir_prefix": tir_prefix, - "relax_prefix": relax_prefix, - "buffer_dtype": buffer_dtype, - "int_dtype": int_dtype, - "float_dtype": float_dtype, - "verbose_expr": verbose_expr, - "indent_spaces": indent_spaces, - "print_line_numbers": print_line_numbers, - "num_context_lines": num_context_lines, - "syntax_sugar": syntax_sugar, - "path_to_underline": path_to_underline, - "path_to_annotate": path_to_annotate, - "obj_to_underline": obj_to_underline, - "obj_to_annotate": obj_to_annotate, - }, + _ffi_node_api.PrinterConfig, cfg # type: ignore # pylint: disable=no-member ) @@ -91,12 +99,19 @@ def _script(obj: Object, config: PrinterConfig) -> str: return _ffi_node_api.TVMScriptPrinterScript(obj, config) # type: ignore # pylint: disable=no-member +def _relax_script(obj: Object, config: PrinterConfig) -> str: + func = get_global_func("script.printer.ReprPrintRelax") + return func(obj, config) + + class Scriptable: """A base class that enables the script() and show() method.""" def script( self, *, + name: Optional[str] = None, + show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", relax_prefix: str = "R", @@ -117,6 +132,10 @@ def script( Parameters ---------- + name : Optional[str] = None + The name of the object + show_meta : bool = False + Whether to print the meta data of the object ir_prefix : str = "I" The prefix of AST nodes from tvm.ir tir_prefix : str = "T" @@ -156,6 +175,52 @@ def script( return _script( self, PrinterConfig( + name=name, + show_meta=show_meta, + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + syntax_sugar=syntax_sugar, + path_to_underline=path_to_underline, + path_to_annotate=path_to_annotate, + obj_to_underline=obj_to_underline, + obj_to_annotate=obj_to_annotate, + ), + ) + + def _relax_script( + self, + *, + name: Optional[str] = None, + show_meta: bool = False, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + syntax_sugar: bool = True, + path_to_underline: Optional[List[ObjectPath]] = None, + path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + obj_to_underline: Optional[List[Object]] = None, + obj_to_annotate: Optional[Dict[Object, str]] = None, + ) -> str: + return _relax_script( + self, + PrinterConfig( + name=name, + show_meta=show_meta, ir_prefix=ir_prefix, tir_prefix=tir_prefix, relax_prefix=relax_prefix, @@ -179,6 +244,8 @@ def show( style: Optional[str] = None, black_format: bool = True, *, + name: Optional[str] = None, + show_meta: bool = False, ir_prefix: str = "I", tir_prefix: str = "T", relax_prefix: str = "R", @@ -204,6 +271,10 @@ def show( `tvm.script.highlight.cprint` for more details. black_format: bool If true (default), use the formatter Black to format the TVMScript + name : Optional[str] = None + The name of the object + show_meta : bool = False + Whether to print the meta data of the object ir_prefix : str = "I" The prefix of AST nodes from tvm.ir tir_prefix : str = "T" @@ -241,6 +312,8 @@ def show( cprint( self.script( + name=name, + show_meta=show_meta, ir_prefix=ir_prefix, tir_prefix=tir_prefix, relax_prefix=relax_prefix, diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d4b280a37fa3..0e76e3d86d6e 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -34,7 +34,7 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import Buffer, BufferRegion, PrimExpr +from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr from tvm.tir import op as _tir_op from tvm.tir import type_annotation @@ -1522,6 +1522,15 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) +def index_map( + mapping: Callable, + *, + inverse_index_map: Optional[Callable] = None, +) -> IndexMap: + """Create a TIR Index mapping""" + return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) + + def target(target_config: Union[Dict, str]) -> Target: """ Create a target @@ -1824,6 +1833,7 @@ def wrapped(*args, **kwargs): "max", "iter_var", "comm_reducer", + "index_map", "target", "buffer_var", "abs", diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 63bba67dd5f2..e024cb361046 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -58,10 +58,20 @@ void ReprLegacyPrinter::Print(const ObjectRef& node) { } else if (f.can_dispatch(node)) { f(node, this); } else { - stream << node; // Use ReprPrinter + try { + stream << node; // Use ReprPrinter + } catch (const tvm::Error& e) { + LOG(WARNING) << "ReprPrinter fails"; + stream << node->GetTypeKey() << '(' << node.get() << ')'; + } } } +bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) { + static const FType& f = vtable(); + return !node.defined() || f.can_dispatch(node); +} + void ReprLegacyPrinter::PrintIndent() { for (int i = 0; i < indent; ++i) { stream << ' '; diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index fcd3c53d026c..071f427a7230 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -37,6 +37,12 @@ std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional config_dict) { runtime::ObjectPtr n = make_object(); + if (auto v = config_dict.Get("name")) { + n->binding_names.push_back(Downcast(v)); + } + if (auto v = config_dict.Get("show_meta")) { + n->show_meta = Downcast(v)->value; + } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 5d743d521777..b73340df30ac 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -262,6 +262,14 @@ TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr") return NullOpt; }); +TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> Optional { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } + return NullOpt; + }); + TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 7f7857dba671..190669aa7a6c 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -42,18 +42,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } return lhs_name < rhs_name; }); - ICHECK(!d->mod.defined()); - d->mod = mod; - { - With f(d); - (*f)->AddDispatchToken(d, "ir"); - for (const auto& kv : functions) { - GlobalVar gv = kv.first; - BaseFunc func = kv.second; - (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); + With f(d); + (*f)->AddDispatchToken(d, "ir"); + for (const auto& kv : functions) { + GlobalVar gv = kv.first; + BaseFunc func = kv.second; + d->cfg->binding_names.push_back(gv->name_hint); + Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv)); + d->cfg->binding_names.pop_back(); + if (const auto* stmt_block = doc.as()) { + (*f)->stmts.push_back(stmt_block->stmts.back()); + } else if (const auto* stmt = doc.as()) { + (*f)->stmts.push_back(GetRef(stmt)); + } else { + (*f)->stmts.push_back(Downcast(doc)); } - return ClassDoc(IdDoc("Module"), {IR(d, "ir_module")}, (*f)->stmts); } + return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")), + {IR(d, "ir_module")}, (*f)->stmts)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -119,9 +125,7 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) { return s.value(); } } - IRDocsifier d(cfg); - Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root())); - return DocToPythonScript(doc, cfg); + return ReprPrintIR(mod, cfg); } TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR); diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index a05030516f3f..6e95cd644b9f 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -34,12 +34,6 @@ namespace tvm { namespace script { namespace printer { -/*! \brief Creates the IR common prefix, which is by default `I` */ -inline ExprDoc IR(const IRDocsifier& d, const String& attr) { - d->ir_usage.insert("ir"); - return IdDoc(d->cfg->ir_prefix)->Attr(attr); -} - class IRFrameNode : public FrameNode { public: void VisitAttrs(AttrVisitor* v) { FrameNode::VisitAttrs(v); } @@ -65,9 +59,7 @@ inline std::string ReprPrintIR(const ObjectRef& obj, const PrinterConfig& cfg) { IRDocsifier d(cfg); With f(d); (*f)->AddDispatchToken(d, "ir"); - std::ostringstream oss; - oss << Docsify(obj, d, *f, cfg); - return oss.str(); + return Docsify(obj, d, *f, cfg); } } // namespace printer diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 5a0d2bd6bbe0..936534480ffb 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -21,26 +21,16 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace printer { -String GenerateUniqueName(std::string name_hint, std::unordered_set* defined_names) { - for (char& c : name_hint) { - if (c != '_' && !std::isalnum(c)) { - c = '_'; - } - } - std::string name = name_hint; - for (int i = 1; !defined_names->insert(name).second; ++i) { - name = name_hint + "_" + std::to_string(i); - } - return name; -} - IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - String name = GenerateUniqueName(name_hint, &this->defined_names); + String name = GenerateUniqueName(name_hint, this->defined_names); + this->defined_names.insert(name); DocCreator doc_factory = [name]() { return IdDoc(name); }; obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); IdDoc def_doc(name); @@ -62,6 +52,17 @@ Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { return it->second.creator(); } +ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { + ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata"; + String key = obj->GetTypeKey(); + Array& array = metadata[key]; + int index = array.size(); + array.push_back(obj); + return IdDoc("metadata") // + [{LiteralDoc::Str(key, NullOpt)}] // + [{LiteralDoc::Int(index, NullOpt)}]; +} + bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index ab91764b6a0b..cc37f46e6036 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -37,12 +37,14 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } + } else { + LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint; } } if (Optional doc = d->GetVarDoc(var)) { return doc.value(); } - LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var; + LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint; } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // @@ -157,6 +159,37 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "comm_reducer")->Call({lambda, id}); }); +LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const ObjectPath& vs_p, + const Array& es, const ObjectPath& es_p, const IRDocsifier& d) { + With f(d, map); + Array vars; + for (int i = 0, l = vs.size(); i < l; ++i) { + vars.push_back(Downcast(DefineVar(vs[i], *f, d))); + } + Array exprs; + for (int i = 0, l = es.size(); i < l; ++i) { + exprs.push_back(d->AsDoc(es[i], es_p->ArrayIndex(i))); + } + return LambdaDoc(vars, TupleDoc(exprs)); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](tir::IndexMap m, ObjectPath m_p, IRDocsifier d) -> Doc { + LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"), + m->final_indices, m_p->Attr("final_indices"), d); + if (m->inverse_index_map.defined()) { + tir::IndexMap inverse = Downcast(m->inverse_index_map); + LambdaDoc inv = PrintIndexMap(inverse, inverse->initial_indices, + m_p->Attr("inverse_index_map")->Attr("initial_indices"), + inverse->final_indices, + m_p->Attr("inverse_index_map")->Attr("final_indices"), d); + return TIR(d, "index_map")->Call({map}, {"inverse_index_map"}, {inv}); + } else { + return TIR(d, "index_map")->Call({map}); + } + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { return TIR(d, "let")->Call({ @@ -250,17 +283,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, OpString)->Call({a, b}); \ }); -#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ - .set_dispatch( \ - "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ - ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ - ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - if (a->IsInstance() && b->IsInstance()) { \ - return TIR(d, OpString)->Call({a, b}); \ - } \ - return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ - }); +bool IsNumber(const ExprDoc& e) { + if (const auto* n = e.as()) { + if (n->value.defined()) { + return n->value->IsInstance() || n->value->IsInstance(); + } + } + return false; +} + +#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", \ + [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + if (IsNumber(a) && IsNumber(b)) { \ + return TIR(d, OpString)->Call({a, b}); \ + } \ + return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ + }); TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd); TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub); @@ -314,6 +356,7 @@ TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::LoadNode, ReprPrintTIR); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index c3f9244962d6..6a4df34a3a7a 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -24,18 +24,6 @@ namespace tvm { namespace script { namespace printer { -String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) { - if (!d->mod.defined()) { - return "main"; - } - for (const auto& kv : d->mod.value()->functions) { - if (kv.second.same_as(f)) { - return kv.first->name_hint; - } - } - return "main"; -} - bool IsSimpleBuffer(const tir::Buffer& buf) { if (!buf->strides.empty()) { return false; @@ -78,7 +66,11 @@ int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { - With frame(MakeDispatchFrame(d, func, func)); + With f(d, func); + (*f)->AddDispatchToken(d, "tir"); + d->SetCommonPrefix(func, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance(); + }); int n_args = func->params.size(); std::unordered_map buffer_data_counter; for (const auto& pair : func->buffer_map) { @@ -100,18 +92,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) tir::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); - args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, - BufferAttn(buffer, buffer_p, *frame, d))); + args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt, + BufferAttn(buffer, buffer_p, *f, d))); buffer_inlined.insert(buffer.get()); continue; } } ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); - args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a)); + args.push_back(AssignDoc(DefineVar(var, *f, d), NullOpt, a)); } // Step 2. Handle `func->attrs` if (func->attrs.defined() && !func->attrs->dict.empty()) { - (*frame)->stmts.push_back( + (*f)->stmts.push_back( ExprStmtDoc(TIR(d, "func_attr") // ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); } @@ -125,10 +117,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } ExprDoc param_doc = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); - ExprDoc lhs = - DefineBuffer(buffer, *frame, d); // TODO(@junrushao): switch `lhs` and `rhs` - ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *frame, d); - (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + ExprDoc lhs = DefineBuffer(buffer, *f, d); // TODO(@junrushao): switch `lhs` and `rhs` + ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d); + (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } } // Step 4. Handle `func->body` @@ -154,18 +145,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (d->cfg->syntax_sugar && implicit_root_block) { tir::Block root_block = implicit_root_block.value(); ObjectPath root_block_p = p->Attr("body")->Attr("block"); - (*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); + (*f)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tir::Buffer buffer = root_block->alloc_buffers[i]; ObjectPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayIndex(i); - IdDoc lhs = DefineBuffer(buffer, *frame, d); - ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d); - (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + IdDoc lhs = DefineBuffer(buffer, *f, d); + ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *f, d); + (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } - AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(), d); + AsDocBody(root_block->body, root_block_p->Attr("body"), f->get(), d); } else { - AsDocBody(func->body, p->Attr("body"), frame->get(), d); + AsDocBody(func->body, p->Attr("body"), f->get(), d); } Optional ret_type = NullOpt; if (func->ret_type.defined()) { @@ -174,21 +165,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ret_type = d->AsDoc(func->ret_type, p->Attr("ret_type")); } } - return FunctionDoc( - /*name=*/IdDoc(FindFunctionName(d, func)), - /*args=*/args, - /*decorators=*/{TIR(d, "prim_func")}, - /*return_type=*/ret_type, - /*body=*/(*frame)->stmts); + return HeaderWrapper(d, FunctionDoc( + /*name=*/IdDoc(FindFunctionName(d, func).value_or("main")), + /*args=*/args, + /*decorators=*/{TIR(d, "prim_func")}, + /*return_type=*/ret_type, + /*body=*/(*f)->stmts)); }); -std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) { - IRDocsifier d(cfg); - Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root())); - return DocToPythonScript(doc, cfg); -} - -TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintPrimFunc); +TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 18c64c5edcfe..08eb12bfa785 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -72,12 +73,6 @@ class TIRFrame : public Frame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); }; -/*! \brief Creates the TIR common prefix, which is by default `T` */ -inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { - d->ir_usage.insert("tir"); - return IdDoc(d->cfg->tir_prefix)->Attr(attr); -} - /*! * \brief Defines a variable in the IRDocsifier at the given frame, * and returns the corresponding IdDoc @@ -171,30 +166,15 @@ inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& return NullOpt; } -/*! - * \brief Create a frame and add dispatch token. Calculate LCA information for the frame. - * \param d The IRDocsifier - * \param root The root of the TIR AST - * \param tir The TIR to be saved in the new TIR frame - * \return The frame created - */ -inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root, - const ObjectRef& tir) { - d->SetCommonPrefix(root, [](const ObjectRef& obj) { - return obj->IsInstance() || obj->IsInstance(); - }); - TIRFrame frame(d, tir); - frame->AddDispatchToken(d, "tir"); - return frame; -} - /*! \brief Redirected method for the ReprPrinter */ inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) { IRDocsifier d(cfg); - With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - std::ostringstream oss; - oss << Docsify(obj, d, *f, cfg); - return oss.str(); + d->SetCommonPrefix(obj, [](const ObjectRef& obj) { + return obj->IsInstance() || obj->IsInstance(); + }); + With f(d, ObjectRef{nullptr}); + (*f)->AddDispatchToken(d, "tir"); + return Docsify(obj, d, *f, cfg); } /*! diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index e90fbc0fb39d..ade19b345215 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -19,10 +19,11 @@ #ifndef TVM_SCRIPT_PRINTER_UTILS_H_ #define TVM_SCRIPT_PRINTER_UTILS_H_ +#include #include #include -#include +#include #include #include @@ -39,9 +40,19 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { try { p->stream << TVMScriptPrinter::Script(obj, NullOpt); } catch (const tvm::Error& e) { - LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" - << e.what(); - p->stream << AsLegacyRepr(obj); + if (ReprLegacyPrinter::CanDispatch(obj)) { + LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" + << e.what(); + try { + p->stream << AsLegacyRepr(obj); + } catch (const tvm::Error& e) { + LOG(WARNING) << "AsLegacyRepr fails. Falling back to the basic address printer"; + } + } else { + LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" + << e.what(); + } + p->stream << obj->GetTypeKey() << '(' << obj.get() << ')'; } } @@ -62,7 +73,37 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra } else { LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); } - return DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + std::ostringstream os; + if (!d->metadata.empty()) { + if (d->cfg->show_meta) { + os << "metadata = tvm.ir.load_json(" + << SaveJSON(Map(d->metadata.begin(), d->metadata.end())) << ")" + << "\n"; + } else { + f->stmts.push_back( + CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it.")); + } + } + os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + return os.str(); +} + +/*! \brief Creates the IR common prefix, which is by default `I` */ +inline ExprDoc IR(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("ir"); + return IdDoc(d->cfg->ir_prefix)->Attr(attr); +} + +/*! \brief Creates the TIR common prefix, which is by default `T` */ +inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("tir"); + return IdDoc(d->cfg->tir_prefix)->Attr(attr); +} + +/*! \brief Creates the TIR common prefix, which is by default `T` */ +inline ExprDoc Relax(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("relax"); + return IdDoc(d->cfg->relax_prefix)->Attr(attr); } inline std::string DType2Str(const runtime::DataType& dtype) { @@ -89,6 +130,34 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { return doc; } +inline Optional GetBindingName(const IRDocsifier& d) { + return d->cfg->binding_names.empty() ? Optional(NullOpt) : d->cfg->binding_names.back(); +} + +inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { + if (Optional name = GetBindingName(d)) { + return name.value(); + } + if (Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { + return sym.value(); + } + return NullOpt; +} + +inline String GenerateUniqueName(std::string name_hint, + const std::unordered_set& defined_names) { + for (char& c : name_hint) { + if (c != '_' && !std::isalnum(c)) { + c = '_'; + } + } + std::string name = name_hint; + for (int i = 1; defined_names.count(name) > 0; ++i) { + name = name_hint + "_" + std::to_string(i); + } + return name; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index ee7e493b61e0..b85464c19f30 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -369,12 +369,6 @@ String IndexMapNode::ToPythonString() const { return String(oss.str()); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "index_map(" << op->ToPythonString() << ")"; - }); - TVM_REGISTER_NODE_TYPE(IndexMapNode); TVM_REGISTER_GLOBAL("tir.IndexMap")