diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 138724ed0693..f1cf815d8ea5 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,16 +22,18 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as _np # type: ignore + import tvm import tvm._ffi -import tvm.relax import tvm.ir +import tvm.relax from tvm import DataType from tvm._ffi import base as _base -from tvm.runtime import ndarray as _nd, Object +from tvm.runtime import Object +from tvm.runtime import ndarray as _nd from ..ir import BaseFunc, Node, SourceName, Span -from ..runtime import String +from ..runtime import Scriptable, String from ..tir import PrimExpr from . import _ffi_api @@ -55,7 +57,7 @@ def __init__(self): # NOTE: place base struct info in expr to avoid cyclic dep # from expr to struct info. -class StructInfo(Node): +class StructInfo(Node, Scriptable): """The base class of all StructInfo. StructInfo contains both the static type @@ -110,7 +112,7 @@ def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp": raise TypeError(f"type {type(rhs)} not supported") -class ExprWithOp(Expr): +class ExprWithOp(Expr, Scriptable): """Basetype of all relax expressions that defines op overloading.""" def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": @@ -436,7 +438,7 @@ def __init__( @tvm._ffi.register_object("relax.expr.PrimValue") -class PrimValue(Expr): +class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" value: PrimExpr @@ -448,7 +450,7 @@ def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.StringImm") -class StringImm(Expr): +class StringImm(Expr, Scriptable): """Represent a string literal constant.""" value: str @@ -458,7 +460,7 @@ def __init__(self, value: str, span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.DataTypeImm") -class DataTypeImm(Expr): +class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" value: DataType @@ -468,11 +470,9 @@ def __init__(self, value: Union[DataType, str], span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.Binding") -class Binding(Node): +class Binding(Node, Scriptable): """The base class of a binding in Relax.""" - ... - @tvm._ffi.register_object("relax.expr.MatchCast") class MatchCast(Binding): @@ -548,7 +548,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> @tvm._ffi.register_object("relax.expr.Function") -class Function(BaseFunc): +class Function(BaseFunc, Scriptable): """A Relax function.""" params: List[Var] @@ -588,35 +588,6 @@ def __call__(self, *args): """ return Call(self, args, None, None) - def script(self, show_meta: bool = False) -> str: - """Print relax.Function into TVMScript - - Parameters - ---------- - show_meta : bool - Whether to show meta information - - Returns - ------- - script : str - The TVM Script of the relax.Function - """ - return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta) # type: ignore - - def show(self, style: str = "light") -> None: - """ - A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygments styles extended by "light" (default) and "dark", by default "light" - """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel - - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style) - @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 45868a488a36..a0aaea886ddc 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -94,13 +94,6 @@ TVM_REGISTER_GLOBAL("relax.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { return Call(op, args, attrs, sinfo_args, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " - << node->sinfo_args << ")"; - }); - If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr n = make_object(); n->cond = std::move(cond); @@ -137,13 +130,6 @@ TVM_REGISTER_GLOBAL("relax.If") return If(cond, true_branch, false_branch, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " - << node->false_branch << ")"; - }); - Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -179,12 +165,6 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o return tuple; } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Tuple(" << node->fields << ")"; - }); - TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); @@ -216,12 +196,6 @@ TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int inde return TupleGetItem(tuple, index); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; - }); - TVM_REGISTER_NODE_TYPE(ShapeExprNode); ShapeExpr::ShapeExpr(Array values, Span span) { @@ -245,19 +219,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, return ShapeExpr(values, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const ShapeExprNode* node = static_cast(ref.get()); - p->stream << "ShapeExpr("; - for (auto it = node->values.begin(); it != node->values.end(); it++) { - if (it != node->values.begin()) { - p->stream << ", "; - } - p->stream << *it; - } - p->stream << ")"; - }); - TVM_REGISTER_NODE_TYPE(VarNode); Var::Var(Id vid, Optional struct_info_annotation, Span span) { @@ -572,12 +533,6 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, return ExternFunc(global_symbol, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "ExternFunc(\"" << node->global_symbol << "\")"; - }); - Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 9db7cea6725d..4004ad28d560 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -41,11 +41,6 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { return ObjectStructInfo(span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "ObjectStructInfo()"; - }); - // Prim PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -60,12 +55,6 @@ TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Sp return PrimStructInfo(dtype, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "PrimStructInfo(" << node->dtype << ")"; - }); - // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { ObjectPtr n = make_object(); @@ -102,16 +91,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - if (node->values.defined()) { - p->stream << "ShapeStructInfo(" << node->values.value() << ")"; - } else { - p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; - } - }); - // Tensor TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -150,16 +129,6 @@ TVM_REGISTER_GLOBAL("relax.TensorStructInfo") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - if (node->shape.defined()) { - p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; - } else { - p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; - } - }); - // Tuple TupleStructInfo::TupleStructInfo(Array fields, Span span) { ObjectPtr n = make_object(); @@ -175,12 +144,6 @@ TVM_REGISTER_GLOBAL("relax.TupleStructInfo") return TupleStructInfo(fields, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "TupleStructInfo(" << node->fields << ")"; - }); - // Func FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { ObjectPtr n = make_object(); @@ -223,12 +186,6 @@ TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; - }); - // Helper functions void UpdateStructInfo(Expr expr, StructInfo struct_info) { ICHECK(!expr->struct_info_.defined()) diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc new file mode 100644 index 000000000000..8a50fe969850 --- /dev/null +++ b/src/script/printer/relax/binding.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& d, // + const Optional& var, const Optional& ann) { + using relax::SeqExpr; + ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); + std::vector> branches{ + PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), + PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), + }; + if (var.defined()) { + for (Array& stmts : branches) { + ExprDoc ret = Downcast(stmts.back())->expr; + stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); + } + } + return IfDoc(cond, branches[0], branches[1]); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { + using relax::StructInfo; + using relax::MatchStructInfo; + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc rhs = Relax(d, "match_cast") + ->Call({d->AsDoc(n->value, n_p->Attr("value")), + d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (const auto if_ = n->value.as()) { + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return PrintIfExpr(GetRef(if_), n_p->Attr("value"), d, lhs, ann); + } else if (n->value->IsInstance()) { + IdDoc lhs = DefineVar(n->var, d->frames.back(), d); + d->cfg->binding_names.push_back(lhs->name); + Doc ret = d->AsDoc(n->value, n_p->Attr("value")); + d->cfg->binding_names.pop_back(); + return ret; + } else { + ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc { + return PrintIfExpr(n, n_p, d, NullOpt, NullOpt); + }); + +TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc new file mode 100644 index 000000000000..2feb2082c510 --- /dev/null +++ b/src/script/printer/relax/call.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class AttrPrinter : public tvm::AttrVisitor { + public: + explicit AttrPrinter(const ObjectPath& p, const IRDocsifier& d, Array* keys, + Array* values) + : p(p), d(d), keys(keys), values(values) {} + + void Visit(const char* key, double* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Float(*value, p->Attr(key))); + } + + void Visit(const char* key, int64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, uint64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, int* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, bool* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Boolean(*value, p->Attr(key))); + } + + void Visit(const char* key, std::string* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Str(*value, p->Attr(key))); + } + + void Visit(const char* key, DataType* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::DataType(*value, p->Attr(key))); + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + keys->push_back(key); + values->push_back(d->AsDoc(*value, p->Attr(key))); + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "TypeError: void is not allowed in Attrs"; + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs"; + } + + const ObjectPath& p; + const IRDocsifier& d; + Array* keys; + Array* values; +}; + +ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifier& d) { + // TODO(@junrushao): handle callee better + if (const auto* ext = n.as()) { + return LiteralDoc::Str(ext->global_symbol, n_p); + } else if (const auto* gv = n.as()) { + IdDoc callee(gv->name_hint); + callee->source_paths.push_back(n_p); + return callee; + } else { + return d->AsDoc(n, n_p); + } +} + +Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (!n->op.same_as(call_tir_op)) { + return NullOpt; + } + ICHECK(n->args.size() == 2 || n->args.size() == 3); + ICHECK(n->sinfo_args.size() == 1); + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print n->args[0], the callee + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + // Step 2. Print n->args[1], the input arguments + args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1))); + // Step 3. Print n->sinfo_args, the output struct info + relax::StructInfo o_sinfo = n->sinfo_args[0]; + ObjectPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayIndex(0); + kwargs_keys.push_back("out_sinfo"); + if (const auto* o = o_sinfo.as()) { + Array fields; + ObjectPath fields_p = o_sinfo_p->Attr("fields"); + for (int i = 0, l = o->fields.size(); i < l; ++i) { + fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayIndex(i))); + } + kwargs_values.push_back(ListDoc(fields)); + } else { + kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); + } + // Step 4. Print n->args[2], the tir variables + if (n->args.size() == 3) { + kwargs_keys.push_back("tir_vars"); + kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayIndex(2))); + } + return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { + // Special case: call_tir + if (Optional doc = PrintCallTIR(n, n_p, d)) { + return doc.value(); + } + ExprDoc prefix{nullptr}; + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print op + if (const auto* op = n->op.as()) { + prefix = Relax(d, "call_packed"); + args.push_back(LiteralDoc::Str(op->global_symbol, n_p->Attr("op"))); + } else if (const auto* op = n->op.as()) { + prefix = IdDoc(op->name_hint); + prefix->source_paths.push_back(n_p->Attr("op")); + } else if (const auto* op = n->op.as()) { + std::string name = op->name; + if (name.rfind("relax.", 0) == 0) { + prefix = Relax(d, name.substr(6)); + } else { + prefix = IdDoc(name); + } + prefix->source_paths.push_back(n_p->Attr("op")); + } else if (n->op->IsInstance()) { + prefix = d->AsDoc(n->op, n_p->Attr("op")); + } else { + LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey(); + } + // Step 2. Print args + if (!n->args.empty()) { + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + } + for (int i = 1, l = n->args.size(); i < l; ++i) { + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + } + // Step 3. Print attrs + if (n->attrs.defined()) { + if (n->op->IsInstance()) { + kwargs_keys.push_back("attrs_type_key"); + kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), n_p->Attr("attrs"))); + } + if (const auto* attrs = n->attrs.as()) { + std::vector> sorted; + for (const auto& kv : attrs->dict) { + sorted.push_back(kv); + } + std::sort(sorted.begin(), sorted.end()); + for (const auto& kv : sorted) { + kwargs_keys.push_back(kv.first); + kwargs_values.push_back( + d->AsDoc(kv.second, n_p->Attr("attrs")->Attr(kv.first))); + } + } else { + AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); + const_cast(n->attrs.get())->VisitAttrs(&printer); + } + } + // Step 4. Print type_args + if (n->sinfo_args.size() > 0) { + ObjectPath sinfo_args_p = n_p->Attr("sinfo_args"); + Array sinfo_args; + for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { + sinfo_args.push_back( + d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayIndex(i))); + } + kwargs_keys.push_back("sinfo_args"); + kwargs_values.push_back(TupleDoc(sinfo_args)); + } + return prefix->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc new file mode 100644 index 000000000000..a786932fc3d9 --- /dev/null +++ b/src/script/printer/relax/expr.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): float numbers + return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::StringImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataTypeImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Tuple n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): revisit tuple printing + if (n->fields.empty()) { + return Relax(d, "tuple")->Call({}); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return TupleDoc(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { + ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); + return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array values_doc; + ObjectPath values_p = n_p->Attr("values"); + for (int i = 0, l = n->values.size(); i < l; ++i) { + values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); + } + return TupleDoc(values_doc); + }); + +Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { + DataType dtype = n.DataType(); + const void* data = n->data; + if (n->ndim != 0 || n->device.device_type != kDLCPU) { + return NullOpt; + } + if (dtype == DataType::Int(32)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Int(64)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(32)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(64)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Bool()) { + return LiteralDoc::Boolean(*reinterpret_cast(data), p); + } else { + return NullOpt; + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Constant n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { + return Relax(d, "const") + ->Call({ + s.value(), + LiteralDoc::DataType(n->data.DataType(), n_p->Attr("data")->Attr("dtype")), + }); + } + return d->AddMetadata(n); + }); + +Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) { + if (!d->IsVarDefined(n)) { + ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); + Frame f = d->frames.back(); + ExprDoc var = DefineVar(n, f, d); + f->stmts.push_back(AssignDoc(var, NullOpt, ann)); + } + return d->GetVarDoc(n).value(); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); + +TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc new file mode 100644 index 000000000000..fa085fcad403 --- /dev/null +++ b/src/script/printer/relax/function.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(RelaxFrameNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { + std::unordered_set func_vars; + With f(d); + (*f)->AddDispatchToken(d, "relax"); + (*f)->is_func = true; + (*f)->func_vars = &func_vars; + // Step 1. Print the return type + Optional ret_type = NullOpt; + if (const auto& func_sinfo = relax::MatchStructInfo(n)) { + ret_type = d->AsDoc(func_sinfo.value()->ret, // + n_p->Attr("struct_info_")->Attr("ret")); + } + // Step 2. Print params + Array params; + { + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, l = n->params.size(); i < l; ++i) { + params.push_back(AssignDoc( + /*lhs=*/DefineVar(n->params[i], *f, d), + /*rhs=*/NullOpt, StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, NullOpt))); + } + } + // Step 3. Clean up func variables + (*f)->func_vars = nullptr; + // Step 4. Print attributes + if (n->attrs.defined() && !n->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } + // Step 5. Print body + Array body = + PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); + (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); + return HeaderWrapper(d, FunctionDoc(IdDoc(FindFunctionName(d, n).value_or("main")), params, + {Relax(d, "function")}, ret_type, (*f)->stmts)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): print more information out of extern function. + return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p)); + }); + +TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc new file mode 100644 index 000000000000..1ac0b5ba14df --- /dev/null +++ b/src/script/printer/relax/region.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret) { + With f(d); + const Array& blocks = n->blocks; + ObjectPath blocks_p = n_p->Attr("blocks"); + Array* stmts = &(*f)->stmts; + for (int i = 0, l = blocks.size(); i < l; ++i) { + Doc block = d->AsDoc(blocks[i], blocks_p->ArrayIndex(i)); + if (const auto* stmt_block = block.as()) { + stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } else if (const auto* stmt = block.as()) { + stmts->push_back(GetRef(stmt)); + } else { + LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); + } + } + ExprDoc ret = d->AsDoc(n->body, n_p->Attr("body")); + if (use_ret) { + stmts->push_back(ReturnDoc(ret)); + } else { + stmts->push_back(ExprStmtDoc(ret)); + } + return *stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::SeqExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); + }); + +Array PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& n_p, + const IRDocsifier& d, Array* non_dataflow_vars) { + const Array& bindings = n->bindings; + ObjectPath bindings_p = n_p->Attr("bindings"); + Array stmts; + for (int i = 0, l = bindings.size(); i < l; ++i) { + const relax::Binding& binding = bindings[i]; + ObjectPath binding_p = bindings_p->ArrayIndex(i); + ICHECK(binding->var.defined()); + Doc binding_doc = d->AsDoc(binding, binding_p); + if (const auto* stmt = binding_doc.as()) { + stmts.push_back(GetRef(stmt)); + } else if (const auto* stmt_block = binding_doc.as()) { + stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } else { + LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey(); + } + if (non_dataflow_vars != nullptr && !binding->var->IsInstance()) { + non_dataflow_vars->push_back(d->AsDoc(binding->var, binding_p->Attr("var"))); + } + } + return stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::BindingBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataflowBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array non_dataflow_vars; + Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); + stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); + return ScopeDoc(NullOpt, Relax(d, "dataflow")->Call({}), stmts); + }); + +TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc new file mode 100644 index 000000000000..6f4a66c991d9 --- /dev/null +++ b/src/script/printer/relax/struct_info.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) { + ExprDoc expr_doc = d->AsDoc(e, e_p); + // Step 1. Find if `func_vars` are being collected + const RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->func_vars) { + f = relax_frame; + break; + } + } + } + // Step 2. Figure out if the PrimExpr contains at least a func var + bool func_var_mode = false; + if (f != nullptr) { + tir::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + if (f->func_vars->count(var)) { + func_var_mode = true; + } + } + }); + } + // Step 3. Stringify the PrimExpr if func var exists + if (func_var_mode) { + return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p); + } + return expr_doc; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->values.defined()) { + Array shape = n->values.value(); + ObjectPath shape_p = n_p->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayIndex(i), d)); + } + return Relax(d, "Shape")->Call({ListDoc(shape_docs)}); + } + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array args; + Array kwargs_keys; + Array kwargs_values; + if (n->shape.defined()) { + args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + } + if (!n->IsUnknownDtype()) { + kwargs_keys.push_back("dtype"); + kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + } + if (!n->shape.defined() && !n->IsUnknownNdim()) { + kwargs_keys.push_back("ndim"); + kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))); + } + if (args.empty() && kwargs_keys.empty()) { + return Relax(d, "Tensor"); + } + return Relax(d, "Tensor")->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->IsOpaque()) { + return Relax(d, "Callable"); + } + // TODO(@junrushao): track symbolic shape relation + Array params_doc; + Array params = n->params.value(); + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, n_params = params.size(); i < n_params; ++i) { + params_doc.push_back(d->AsDoc(params[i], params_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(params_doc), // + d->AsDoc(n->ret, n_p->Attr("ret"))}); + }); + +TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc new file mode 100644 index 000000000000..2c8bb0f1da6c --- /dev/null +++ b/src/script/printer/relax/tir.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { + ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses " + "scalar integer TIR variables, but gets: " + << n; + if (!d->IsVarDefined(n)) { + // Find the outmost Relax function frame. If not exist, the outmost Relax frame. + RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->is_func) { + f = const_cast(relax_frame); + break; + } else if (f == nullptr) { + f = const_cast(relax_frame); + } + } + } + // There should be at least one Relax frame + if (f == nullptr) { + LOG(FATAL) << "IndexError: No relax environment is found when printing a TIR var under " + "relax's dispatch token"; + } + // If the Relax function frame is collecting func vars + if (f->func_vars) { + ICHECK(f->is_func); + f->func_vars->insert(n.get()); + } + IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); + var->source_paths.push_back(n_p); + f->stmts.push_back(AssignDoc(var, + TIR(d, "Var")->Call({ + LiteralDoc::Str(var->name, n_p->Attr("name_hint")), + LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")), + }), + NullOpt)); + } + if (Optional doc = d->GetVarDoc(n)) { + return doc.value(); + } + LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // + // TODO(@junrushao): support non-int64 cases + return LiteralDoc::Int(n->value, n_p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + IdDoc ret(n->name_hint); + ret->source_paths.push_back(n_p); + return ret; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc new file mode 100644 index 000000000000..d13d90b1d5ed --- /dev/null +++ b/src/script/printer/relax/type.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::DynTensorType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Tensor") + ->Call({}, {"ndim", "dtype"}, + {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), + LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PackedFuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::TupleType n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "relax", [](tvm::FuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array arg_types_doc; + Array arg_types = n->arg_types; + ObjectPath arg_types_p = n_p->Attr("arg_types"); + for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { + arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(arg_types_doc), // + d->AsDoc(n->ret_type, n_p->Attr("ret_type"))}); + }); + +TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DynTensorTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); +TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h new file mode 100644 index 000000000000..7702f7b22dd2 --- /dev/null +++ b/src/script/printer/relax/utils.h @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class RelaxFrameNode : public FrameNode { + public: + bool is_func = false; + std::unordered_set* func_vars = nullptr; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("is_global_func", &is_func); + // `func_var_to_define` is not visited + } + + static constexpr const char* _type_key = "script.printer.RelaxFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode); +}; + +class RelaxFrame : public Frame { + public: + explicit RelaxFrame(const IRDocsifier& d) { + ObjectPtr n = make_object(); + n->stmts.clear(); + n->d = d.get(); + n->is_func = false; + n->func_vars = nullptr; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, Frame, RelaxFrameNode); +}; + +/*! \brief Redirected method for the ReprPrinter */ +inline std::string ReprPrintRelax(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); + With f(d); + (*f)->AddDispatchToken(d, "relax"); + return Docsify(obj, d, *f, cfg); +} + +inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsifier& d) { + return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); +} + +inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p, + const IRDocsifier& d, const Optional& rhs) { + if (!v->struct_info_.defined()) { + return NullOpt; + } + if (const auto* call = rhs.as()) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op.same_as(call_tir_op)) { + return NullOpt; + } + } + return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); +} + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py new file mode 100644 index 000000000000..75fc4d14296c --- /dev/null +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -0,0 +1,489 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import pytest +from tvm import IRModule, relax, tir +from tvm.script import relax as R + + +def _assert_print(obj, expected): + if not isinstance(obj, str): + obj = obj.script(verbose_expr=True) + obj = obj.strip() + assert obj == expected.strip(), "\n" + obj + + +def test_function(): + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + _assert_print( + func, + """ +# from tvm.script import relax as R + +@R.function +def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a""", + ) + + +def test_extern_func(): + @R.function + def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + obj = IRModule( + { + "func": relax_func, + "my_ext": relax.ExternFunc("my_ext"), + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a + + "my_ext" +""", + ) + + +def test_object_struct_info(): + obj = relax.ObjectStructInfo() + _assert_print( + obj, + "R.Object", + ) + + +def test_prim_struct_info(): + obj = relax.PrimStructInfo("float32") + _assert_print(obj, 'R.Prim("float32")') + + +def test_shape_struct_info_0(): + obj = relax.ShapeStructInfo(ndim=-1) + _assert_print(obj, "R.Shape(ndim=-1)") + + +def test_shape_struct_info_1(): + obj = relax.ShapeStructInfo([1, 2, 3]) + _assert_print(obj, "R.Shape([1, 2, 3])") + + +def test_shape_struct_info_2(): + obj = relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Shape([1, a, 3])""", + ) + + +def test_tensor_struct_info(): + obj = relax.TensorStructInfo( + shape=relax.ShapeExpr([1, tir.Var("a", "int64"), 3]), + dtype="float32", + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Tensor((1, a, 3), dtype="float32") +""", + ) + + +def test_tuple_struct_info_empty(): + obj = relax.TupleStructInfo([]) + _assert_print(obj, "R.Tuple") + + +def test_tuple_struct_info(): + obj = relax.TupleStructInfo( + [ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ] + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3])) +""", + ) + + +def test_func_struct_info(): + obj = relax.FuncStructInfo( + params=[ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ], + ret=relax.TensorStructInfo( + shape=relax.ShapeExpr([1, 2, 3]), + dtype="float32", + ), + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32")) +""", + ) + + +def test_shape_type(): + obj = relax.ShapeType(ndim=3) + _assert_print(obj, "R.Shape(ndim=3)") + + +def test_object_type(): + obj = relax.ObjectType() + _assert_print(obj, "R.Object") + + +def test_dyn_tensor_type(): + obj = relax.DynTensorType() + _assert_print(obj, 'R.Tensor(ndim=-1, dtype="float32")') + + +def test_packed_func_type(): + obj = relax.PackedFuncType() + _assert_print(obj, "R.PackedFunc") + + +def test_tuple_type(): + obj = relax.TupleType([relax.ShapeType(ndim=3), relax.ObjectType()]) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + "R.Tuple(R.Shape(ndim=3), R.Object)", + ) + + +def test_func_type(): + obj = relax.FuncType( + arg_types=[ + relax.ObjectType(), + relax.ShapeType(ndim=3), + ], + ret_type=relax.DynTensorType( + ndim=3, + dtype="float32", + ), + ) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + 'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(ndim=3, dtype="float32"))', + ) + + +def test_prim_value(): + obj = relax.PrimValue(1) + _assert_print(obj, "R.prim_value(1)") + + +def test_string_imm(): + obj = relax.StringImm("hello") + _assert_print(obj, 'R.str("hello")') + + +def test_data_type_imm(): + obj = relax.DataTypeImm("float32") + _assert_print(obj, 'R.dtype("float32")') + + +def test_var(): + obj = relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_dataflow_var(): + obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_tuple(): + obj = relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +y = T.Var("y", "int64") +b: R.Tensor((1, y, 3), dtype="float32") +z = T.Var("z", "int64") +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c) +""", + ) + + +def test_tuple_get_item(): + obj = relax.TupleGetItem( + relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ), + 0, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +y = T.Var("y", "int64") +b: R.Tensor((1, y, 3), dtype="float32") +z = T.Var("z", "int64") +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c)[0] +""", + ) + + +def test_shape_expr(): + obj = relax.ShapeExpr([1, 2, 3]) + _assert_print(obj, "(1, 2, 3)") + + +def test_call(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, tir_vars=[x]) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,)) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_seq_expr(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + + obj = relax.SeqExpr( + blocks=[ + relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ), + ], + body=c, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) + R.output(c) +c +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_binding_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.BindingBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_dataflow_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) + R.output(c) +""", + ) + + +def test_match_cast(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3])) + b = relax.Var("b", relax.TensorStructInfo([1, 5, 3])) + obj = relax.MatchCast( + var=b, + value=a, + struct_info=b.struct_info, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, 5, 3), dtype="float32") = R.match_cast(a, R.Tensor((1, 5, 3), dtype="float32")) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_var_binding(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.VarBinding(b, relax.op.sin(a)) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +""", + ) + + +def test_if(): + a = relax.Var("a", relax.TensorStructInfo([], "bool")) + b = relax.Var("b", relax.TensorStructInfo([1, 2, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, 2, 3], "float32")) + obj = relax.If( + a, + relax.SeqExpr([], b), + relax.SeqExpr([], c), + ) + _assert_print( + obj, + """ +a: R.Tensor((), dtype="bool") +if a: + b: R.Tensor((1, 2, 3), dtype="float32") + b +else: + c: R.Tensor((1, 2, 3), dtype="float32") + c +""", + ) + + +if __name__ == "__main__": + test_function() + test_extern_func() + + test_object_struct_info() + test_prim_struct_info() + test_shape_struct_info_0() + test_shape_struct_info_1() + test_shape_struct_info_2() + test_tensor_struct_info() + test_tuple_struct_info_empty() + test_tuple_struct_info() + test_func_struct_info() + + test_shape_type() + test_object_type() + test_dyn_tensor_type() + test_packed_func_type() + test_tuple_type() + test_func_type() + + test_prim_value() + test_string_imm() + test_data_type_imm() + + test_var() + test_dataflow_var() + # + test_tuple() + test_tuple_get_item() + test_shape_expr() + test_call() + + test_seq_expr() + test_binding_block() + test_dataflow_block() + + test_match_cast() + test_var_binding() + test_if()