Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class ReprLegacyPrinter {
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
/*! \brief Could the LegacyPrinter dispatch the node */
TVM_DLL static bool CanDispatch(const ObjectRef& node);
/*! \brief Return the ostream it maintains */
TVM_DLL std::ostream& Stream() const;
// Allow registration to be printer.
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ namespace tvm {

class PrinterConfigNode : public Object {
public:
/*! \brief A stack that tracks the names of the binding hierarchy */
Array<String> binding_names = {};
/*! \brief Whether or not to show metadata. */
bool show_meta = false;
/*! \brief The prefix of IR nodes */
std::string ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
Expand Down Expand Up @@ -71,6 +75,8 @@ class PrinterConfigNode : public Object {
Map<ObjectRef, String> obj_to_annotate = Map<ObjectRef, String>();

void VisitAttrs(AttrVisitor* v) {
v->Visit("binding_names", &binding_names);
v->Visit("show_meta", &show_meta);
v->Visit("ir_prefix", &ir_prefix);
v->Visit("buffer_dtype", &buffer_dtype);
v->Visit("int_dtype", &int_dtype);
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ class IRDocsifierNode : public Object {
* when converting IR node object to Doc.
*/
Array<String> dispatch_tokens;
/*! \brief The IRModule to be docsifier is handling */
Optional<IRModule> mod;
/*! \brief Mapping from a var to its info */
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
/*! \brief The variable names used already */
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
Expand All @@ -155,8 +155,8 @@ class IRDocsifierNode : public Object {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("mod", &mod);
// `obj2info` is not visited
// `metadata` is not visited
// `defined_names` is not visited
// `common_prefix` is not visited
// `ir_usage` is not visited
Expand Down Expand Up @@ -204,7 +204,8 @@ class IRDocsifierNode : public Object {
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;

/*! \brief Add a TVM object to the metadata section*/
ExprDoc AddMetadata(const ObjectRef& obj);
/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
Expand Down
113 changes: 93 additions & 20 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Configuration of TVMScript printer"""
from typing import List, Dict, Optional
from typing import Dict, List, Optional, Sequence

from tvm._ffi import register_object
from tvm._ffi import get_global_func, register_object
from tvm.runtime import Object

from . import _ffi_node_api
Expand All @@ -28,6 +28,8 @@
class PrinterConfig(Object):
"""Configuration of TVMScript printer"""

binding_names: Sequence[str]
show_meta: bool
ir_prefix: str
tir_prefix: str
relax_prefix: str
Expand All @@ -47,6 +49,8 @@ class PrinterConfig(Object):
def __init__(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -65,38 +69,49 @@ def __init__(
) -> None:
if num_context_lines is None:
num_context_lines = -1
cfg = {
"show_meta": show_meta,
"ir_prefix": ir_prefix,
"tir_prefix": tir_prefix,
"relax_prefix": relax_prefix,
"buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
"indent_spaces": indent_spaces,
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
}

if name is not None:
cfg["name"] = name
self.__init_handle_by_constructor__(
_ffi_node_api.PrinterConfig, # type: ignore # pylint: disable=no-member
{
"ir_prefix": ir_prefix,
"tir_prefix": tir_prefix,
"relax_prefix": relax_prefix,
"buffer_dtype": buffer_dtype,
"int_dtype": int_dtype,
"float_dtype": float_dtype,
"verbose_expr": verbose_expr,
"indent_spaces": indent_spaces,
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
"obj_to_annotate": obj_to_annotate,
},
_ffi_node_api.PrinterConfig, cfg # type: ignore # pylint: disable=no-member
)


def _script(obj: Object, config: PrinterConfig) -> str:
return _ffi_node_api.TVMScriptPrinterScript(obj, config) # type: ignore # pylint: disable=no-member


def _relax_script(obj: Object, config: PrinterConfig) -> str:
func = get_global_func("script.printer.ReprPrintRelax")
return func(obj, config)


Comment on lines +102 to +106
Copy link
Contributor

@leandron leandron Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @junrushao - it seems these functions accidentally were merged into main as opposed to unity - can you push a follow-up PR and change it to the appropriate branch?

It is one of the cases that a change was merged quite soon (in about 5h I think) so not many people were able to look into it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I don;t think there is any issue future proofing the particular implementation. Given that non of the existing works are disrupted.

Ultimately i would lean more towards the maintainer of the modules, given they volunteered their time and effort building the module in a scoped way. So unless this s disrupting some of the existing modules, i would suggest let the maintainers make the call

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing a bit more digging, this is actually technical debt/bug as script.printer.ReprPrintRelax is not something that exists in this code base, rather than any future proofing, so I'd like for this to be seriously considered for a fix.

Moreover, it is not really scoped as it is half here, and half in a separate branch.

Copy link
Member

@tqchen tqchen Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is scoped in a sense that the name itself clearly differentiates itself from other functions and is confined within the printer module.

From a technical pov, it is common for code to have plugin registries that queries and runs function when available. Atm there does not seems to be a user facing regression triggered by the particular internal function, if there are I certainly trust the related maintainers for a fix. So i do not see it as a bug given my experience in this area.

Given this is an internal API here, i think the overall impact is scoped and minimum in this module (script printer) with clear name, folder structure etc, so again i will let the maintainers to make the call

class Scriptable:
"""A base class that enables the script() and show() method."""

def script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -117,6 +132,10 @@ def script(

Parameters
----------
name : Optional[str] = None
The name of the object
show_meta : bool = False
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
Expand Down Expand Up @@ -156,6 +175,52 @@ def script(
return _script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
indent_spaces=indent_spaces,
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
obj_to_annotate=obj_to_annotate,
),
)

def _relax_script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
obj_to_annotate: Optional[Dict[Object, str]] = None,
) -> str:
return _relax_script(
self,
PrinterConfig(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
Expand All @@ -179,6 +244,8 @@ def show(
style: Optional[str] = None,
black_format: bool = True,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
Expand All @@ -204,6 +271,10 @@ def show(
`tvm.script.highlight.cprint` for more details.
black_format: bool
If true (default), use the formatter Black to format the TVMScript
name : Optional[str] = None
The name of the object
show_meta : bool = False
Whether to print the meta data of the object
ir_prefix : str = "I"
The prefix of AST nodes from tvm.ir
tir_prefix : str = "T"
Expand Down Expand Up @@ -241,6 +312,8 @@ def show(

cprint(
self.script(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# pylint: disable=unused-import
from tvm.target.codegen import llvm_lookup_intrinsic_id
from tvm.tir import Buffer, BufferRegion, PrimExpr
from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr
from tvm.tir import op as _tir_op
from tvm.tir import type_annotation

Expand Down Expand Up @@ -1522,6 +1522,15 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)


def index_map(
mapping: Callable,
*,
inverse_index_map: Optional[Callable] = None,
) -> IndexMap:
"""Create a TIR Index mapping"""
return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map)


def target(target_config: Union[Dict, str]) -> Target:
"""
Create a target
Expand Down Expand Up @@ -1824,6 +1833,7 @@ def wrapped(*args, **kwargs):
"max",
"iter_var",
"comm_reducer",
"index_map",
"target",
"buffer_var",
"abs",
Expand Down
12 changes: 11 additions & 1 deletion src/node/repr_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ void ReprLegacyPrinter::Print(const ObjectRef& node) {
} else if (f.can_dispatch(node)) {
f(node, this);
} else {
stream << node; // Use ReprPrinter
try {
stream << node; // Use ReprPrinter
} catch (const tvm::Error& e) {
LOG(WARNING) << "ReprPrinter fails";
stream << node->GetTypeKey() << '(' << node.get() << ')';
}
}
}

bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) {
static const FType& f = vtable();
return !node.defined() || f.can_dispatch(node);
}

void ReprLegacyPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
Expand Down
6 changes: 6 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<Print

PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
runtime::ObjectPtr<PrinterConfigNode> n = make_object<PrinterConfigNode>();
if (auto v = config_dict.Get("name")) {
n->binding_names.push_back(Downcast<String>(v));
}
if (auto v = config_dict.Get("show_meta")) {
n->show_meta = Downcast<IntImm>(v)->value;
}
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<String>(v);
}
Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
return NullOpt;
});

TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
if (func->IsInstance<relay::FunctionNode>()) {
return WithoutAttr(Downcast<relay::Function>(std::move(func)), key);
}
return NullOpt;
});

TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_GLOBAL("relay.ir.Function")
Expand Down
30 changes: 17 additions & 13 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
return lhs_name < rhs_name;
});
ICHECK(!d->mod.defined());
d->mod = mod;
{
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
for (const auto& kv : functions) {
GlobalVar gv = kv.first;
BaseFunc func = kv.second;
(*f)->stmts.push_back(d->AsDoc<FunctionDoc>(func, p->Attr("functions")->MapValue(gv)));
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
for (const auto& kv : functions) {
GlobalVar gv = kv.first;
BaseFunc func = kv.second;
d->cfg->binding_names.push_back(gv->name_hint);
Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
} else if (const auto* stmt = doc.as<StmtDocNode>()) {
(*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
} else {
(*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
}
return ClassDoc(IdDoc("Module"), {IR(d, "ir_module")}, (*f)->stmts);
}
return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
{IR(d, "ir_module")}, (*f)->stmts));
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down Expand Up @@ -119,9 +125,7 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
return s.value();
}
}
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
return ReprPrintIR(mod, cfg);
}

TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR);
Expand Down
Loading