From 0f31975826d8a4467bfe65889cdf96915e366e65 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 17 Jan 2023 10:49:07 -0800 Subject: [PATCH] [TVMScript] Use TVMScript for all TIR Printing This PR clarifies TVMScript is the only possible format for TIR printing by default. To do this, the PR makes the following changes: - `AsText` and `PrettyPrint` are moved to Relay, so that in code, it stresses that they are specifically used for non-TIR usecases - The use of legacy behavior, for example, `TIRTextPrinterDebug`, `PrettyPrint`, `AsText`, `AsLegacyRepr`, are still available in both C++ and Python, but are discouraged in a way that one needs to include extra headers. - `astext` method is removed from TIR nodes, but remain available for Relay and others --- CMakeLists.txt | 2 +- include/tvm/ir/module.h | 28 ---- include/tvm/ir/transform.h | 1 - include/tvm/relay/base.h | 28 ++++ include/tvm/{ir => relay}/error.h | 13 +- include/tvm/relay/expr.h | 1 - include/tvm/relay/expr_functor.h | 2 +- include/tvm/relay/pattern_functor.h | 2 +- python/tvm/ir/__init__.py | 1 - python/tvm/ir/affine_type.py | 2 +- python/tvm/ir/base.py | 31 ----- python/tvm/ir/expr.py | 32 ++++- python/tvm/ir/module.py | 28 ++++ python/tvm/ir/op.py | 31 ++++- python/tvm/ir/tensor_type.py | 2 +- python/tvm/micro/model_library_format.py | 7 +- python/tvm/relay/__init__.py | 1 + python/tvm/relay/base.py | 39 +++++- python/tvm/relay/dataflow_pattern/__init__.py | 29 +++- python/tvm/relay/expr.py | 34 ++++- python/tvm/relay/function.py | 29 +++- python/tvm/script/__init__.py | 1 - python/tvm/script/printer/__init__.py | 1 - python/tvm/script/printer/printer.py | 54 -------- rust/tvm/src/ir/expr.rs | 2 +- src/ir/transform.cc | 6 +- src/relay/analysis/annotated_region_set.cc | 2 +- src/relay/analysis/annotated_region_set.h | 2 +- src/relay/analysis/kind_check.cc | 2 +- src/relay/analysis/match_exhaustion.cc | 2 +- src/relay/analysis/type_solver.h | 2 +- src/relay/backend/contrib/ethosu/codegen.cc | 2 +- .../backend/contrib/ethosu/compiler_attrs.cc | 2 +- .../backend/contrib/ethosu/preprocess.cc | 2 +- src/relay/backend/contrib/uma/relay_to_tir.cc | 2 +- src/relay/backend/vm/compiler.cc | 2 +- src/relay/backend/vm/compiler.h | 2 +- src/relay/collage/partition_rule.h | 2 +- src/relay/ir/base.cc | 5 + src/{ => relay}/ir/error.cc | 11 +- src/relay/op/tensor/transform.cc | 2 +- src/relay/op/tensor/transform.h | 2 +- src/relay/op/type_relations.h | 2 +- src/{ => relay}/printer/doc.cc | 4 +- src/{ => relay}/printer/doc.h | 9 +- src/{ => relay}/printer/meta_data.h | 13 +- .../printer/model_library_format_printer.cc | 6 +- src/{ => relay}/printer/relay_text_printer.cc | 13 +- src/{ => relay}/printer/text_printer.cc | 9 +- src/{ => relay}/printer/text_printer.h | 47 +++---- src/{ => relay}/printer/tir_text_printer.cc | 28 ++-- .../printer/tir_text_printer_debug.cc | 4 +- .../printer/tir_text_printer_debug.h | 10 +- src/{ => relay}/printer/tvmscript_printer.cc | 85 ++++++------ .../transforms/merge_compiler_regions.cc | 2 +- src/relay/transforms/partition_graph.cc | 2 +- src/script/printer/printer.cc | 7 - src/tir/schedule/error.cc | 6 +- src/tir/transforms/install_debug_spans.cc | 4 +- tests/python/relay/test_ir_parser.py | 10 +- .../test_meta_schedule_schedule_rule_mlt.py | 3 +- tests/python/unittest/test_tir_nodes.py | 126 ------------------ .../test_tir_transform_lower_warp_memory.py | 9 +- .../test_tvmscript_printer_syntax_sugar.py | 69 ---------- .../unittest/test_tvmscript_printer_tir.py | 42 ++++++ 65 files changed, 447 insertions(+), 514 deletions(-) rename include/tvm/{ir => relay}/error.h (97%) delete mode 100644 python/tvm/script/printer/printer.py rename src/{ => relay}/ir/error.cc (97%) rename src/{ => relay}/printer/doc.cc (98%) rename src/{ => relay}/printer/doc.h (97%) rename src/{ => relay}/printer/meta_data.h (95%) rename src/{ => relay}/printer/model_library_format_printer.cc (96%) rename src/{ => relay}/printer/relay_text_printer.cc (99%) rename src/{ => relay}/printer/text_printer.cc (95%) rename src/{ => relay}/printer/text_printer.h (95%) rename src/{ => relay}/printer/tir_text_printer.cc (97%) rename src/{ => relay}/printer/tir_text_printer_debug.cc (98%) rename src/{ => relay}/printer/tir_text_printer_debug.h (90%) rename src/{ => relay}/printer/tvmscript_printer.cc (96%) delete mode 100644 tests/python/unittest/test_tvmscript_printer_syntax_sugar.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a8d8b733ee1..36f7d379620f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,7 +288,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/topi/*.cc src/driver/*.cc src/parser/*.cc - src/printer/*.cc src/support/*.cc src/script/*.cc ) @@ -317,6 +316,7 @@ tvm_file_glob(GLOB RELAY_BACKEND_SRCS ) tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS src/relay/ir/*.cc + src/relay/printer/*.cc ) tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS src/relay/qnn/*.cc diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index f26e640f6c22..4cd357d4180b 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -446,34 +446,6 @@ class IRModule : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); }; -/*! - * \brief Pretty print a node for debug purposes. - * - * \param node The node to be printed. - * \return The text reperesentation. - * \note This function does not show version or meta-data. - * Use AsText if you want to store the text. - * \sa AsText. - */ -TVM_DLL String PrettyPrint(const ObjectRef& node); - -/*! - * \brief Render the node as a string in the text format. - * - * \param node The node to be rendered. - * \param show_meta_data Whether to print meta data section. - * \param annotate An optional callback function for attaching - * additional comment block to an expr. - * - * \note We support a limited set of IR nodes that are part of - * relay IR and - * - * \sa PrettyPrint. - * \return The text representation. - */ -TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); - namespace attr { // Following are attributes for IRModule only. diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index febcca5c0107..473e6291685d 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -57,7 +57,6 @@ #define TVM_IR_TRANSFORM_H_ #include -#include #include #include #include diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index e94bd2756e98..2825bcfc659a 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -120,6 +120,34 @@ class Id : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); }; +/*! + * \brief Pretty print a node for debug purposes. + * + * \param node The node to be printed. + * \return The text reperesentation. + * \note This function does not show version or meta-data. + * Use AsText if you want to store the text. + * \sa AsText. + */ +TVM_DLL String PrettyPrint(const ObjectRef& node); + +/*! + * \brief Render the node as a string in the text format. + * + * \param node The node to be rendered. + * \param show_meta_data Whether to print meta data section. + * \param annotate An optional callback function for attaching + * additional comment block to an expr. + * + * \note We support a limited set of IR nodes that are part of + * relay IR and + * + * \sa PrettyPrint. + * \return The text representation. + */ +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); + } // namespace relay } // namespace tvm diff --git a/include/tvm/ir/error.h b/include/tvm/relay/error.h similarity index 97% rename from include/tvm/ir/error.h rename to include/tvm/relay/error.h index 6ff61781ac44..be34e2b8ae1a 100644 --- a/include/tvm/ir/error.h +++ b/include/tvm/relay/error.h @@ -16,13 +16,8 @@ * specific language governing permissions and limitations * under the License. */ - -/*! - * \file tvm/ir/error.h - * \brief Utilities for error tracking and reporting. - */ -#ifndef TVM_IR_ERROR_H_ -#define TVM_IR_ERROR_H_ +#ifndef TVM_RELAY_ERROR_H_ +#define TVM_RELAY_ERROR_H_ #include #include @@ -33,6 +28,7 @@ #include namespace tvm { +namespace relay { /*! * \brief A wrapper around std::stringstream to build error. * @@ -181,5 +177,6 @@ class ErrorReporter { std::unordered_map node_to_gv_; }; +} // namespace relay } // namespace tvm -#endif // TVM_IR_ERROR_H_ +#endif // TVM_RELAY_ERROR_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 6847a53caad4..854050464d4a 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -57,7 +57,6 @@ using BaseFunc = tvm::BaseFunc; using BaseFuncNode = tvm::BaseFuncNode; using GlobalVar = tvm::GlobalVar; using GlobalVarNode = tvm::GlobalVarNode; -using tvm::PrettyPrint; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 280a1f8a6c29..2a295c9da7f9 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -25,9 +25,9 @@ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ -#include #include #include +#include #include #include #include diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 711d8323f158..9d2b6689b2c2 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -25,8 +25,8 @@ #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_ -#include #include +#include #include #include diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 9e81dd5519e1..4f63cbecd9d1 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -27,7 +27,6 @@ Span, assert_structural_equal, load_json, - pretty_print, save_json, structural_equal, structural_hash, diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index 8d185ae59a34..24126f94b9c4 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -32,7 +32,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + from tvm.relay import pretty_print # pylint: disable=import-outside-toplevel return pretty_print(self) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index a1e1d20d8823..b84a83d55843 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -23,40 +23,9 @@ from . import _ffi_api, json_compact -def pretty_print(obj: Object) -> None: - """Pretty print the object.""" - return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member - - class Node(Object): """Base class of all IR Nodes, implements astext function.""" - def astext(self, show_meta_data=True, annotate=None): - """Get the text format of the expression. - - Parameters - ---------- - show_meta_data : bool - Whether to include meta data section in the text - if there is meta data. - - annotate: Optional[Object->str] - Optionally annotate function to provide additional - information in the comment block. - - Returns - ------- - text : str - The text format of the expression. - - Notes - ----- - The meta data section is necessary to fully parse the text format. - However, it can contain dumps that are big (e.g constant weights), - so it can be helpful to skip printing the meta data section. - """ - return _ffi_api.AsText(self, show_meta_data, annotate) - @tvm._ffi.register_object("SourceName") class SourceName(Object): diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index e16cd5ea9e2f..52af8407b7a0 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -17,9 +17,9 @@ """Common expressions data structures in the IR.""" import tvm._ffi -from .base import Node -from . import _ffi_api from ..runtime import const, convert +from . import _ffi_api +from .base import Node class BaseExpr(Node): @@ -91,6 +91,34 @@ def __call__(self, *args): "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) ) + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + from tvm.relay import astext # pylint: disable=import-outside-toplevel + + return astext(self, show_meta_data, annotate) + @tvm._ffi.register_object class Range(Node): diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index b184c3b0c3cf..51410049ec74 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -287,6 +287,34 @@ def with_attr(self, attr_key, attr_value): return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + from tvm.relay import astext # pylint: disable=import-outside-toplevel + + return astext(self, show_meta_data, annotate) + def script( self, *, diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 49ac72b887e6..70aba979518e 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -17,8 +17,9 @@ # pylint: disable=invalid-name """Primitive operators in the TVM IR.""" import tvm._ffi -from .expr import RelayExpr + from . import _ffi_api +from .expr import RelayExpr @tvm._ffi.register_object("Op") @@ -28,6 +29,34 @@ class Op(RelayExpr): def __init__(self): raise RuntimeError("Cannot create op, use get instead") + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + from tvm.relay import astext # pylint: disable=import-outside-toplevel + + return astext(self, show_meta_data, annotate) + @staticmethod def get(op_name): """Get the Op for a given name diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py index 7313f3c2b42c..495e0fe868e5 100644 --- a/python/tvm/ir/tensor_type.py +++ b/python/tvm/ir/tensor_type.py @@ -56,6 +56,6 @@ def concrete_shape(self): return tuple(int(x) for x in self.shape) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel + from tvm.relay import pretty_print # pylint: disable=import-outside-toplevel return pretty_print(self) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 0f30c39ad476..fc32fe34d6c9 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -27,12 +27,13 @@ import tvm from tvm.micro import get_standalone_crt_dir + from .._ffi import get_global_func from ..contrib import utils from ..driver import build_module -from ..relay.backend import executor_factory -from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name from ..relay import param_dict +from ..relay.backend import executor_factory +from ..relay.backend.name_transforms import prefix_generated_name, to_c_variable_style from ..tir import expr # This should be kept identical to runtime::symbol::tvm_module_main @@ -528,7 +529,7 @@ def _eval_shape(param_name, buffer_shape): # TODO(mbs): The device type is not unique, better would be to use target.kind.name target_device_type = target.get_target_device_type() ir_mod = ir_module_by_target[target] - printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) + printer = get_global_func("relay.ir.ModelLibraryFormatPrinter")(False, None, False) with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: f.write(printer["print"](ir_mod)) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 97842738e5cd..5e5d1d5f18d8 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -29,6 +29,7 @@ from . import prelude from . import loops from . import scope_builder +from .base import pretty_print, astext from . import transform from . import analysis diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 323a8f6e5a01..8667bfb1dfdc 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -17,15 +17,50 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, unused-import """The base node types for the Relay language.""" import os -import tvm._ffi +import tvm._ffi +from tvm.ir import Node as RelayNode +from tvm.ir import SourceName, Span from tvm.runtime import Object -from tvm.ir import SourceName, Span, Node as RelayNode +from . import _ffi_api __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") +def pretty_print(obj: Object) -> None: + """Pretty print the object.""" + return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member + + +def astext(obj: Object, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + obj : Object + The object to be printed. + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + return _ffi_api.AsText(obj, show_meta_data, annotate) # type: ignore # pylint: disable=no-member + + @tvm._ffi.register_func("tvm.relay.std_path") def _std_path(): return __STD_PATH__ diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 6c29825bc04d..6e19cafa747d 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -26,6 +26,7 @@ from ...ir import make_node from ...ir.base import Node from ...runtime import Object +from ..base import astext, pretty_print from ..op import get from . import _ffi as ffi @@ -47,10 +48,34 @@ class DFPattern(Node): """Base class of all Patterns.""" def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel - return pretty_print(self) + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + return astext(self, show_meta_data, annotate) + def __call__(self, *args): args = list(args) if len(args) == 1 and args[0] is None: diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 7d60e89b59b7..cb14552ac16e 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -30,7 +30,7 @@ from . import _ffi_api from . import ty as _ty -from .base import RelayNode +from .base import RelayNode, astext, pretty_print # alias relay expr as Expr. Expr = RelayExpr @@ -62,10 +62,34 @@ def astype(self, dtype): return _ffi_api.cast(self, dtype) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel - return pretty_print(self) + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + return astext(self, show_meta_data, annotate) + def __neg__(self): return _op_make.negative(self) @@ -719,8 +743,6 @@ def __init__(self, sids, dev_types, sizes): self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel - return pretty_print(self) @property @@ -750,6 +772,4 @@ def __init__(self, expr_to_storage_info): self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel - return pretty_print(self) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index ef3356450085..dc0636a9b3f4 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -23,6 +23,7 @@ from tvm.runtime import convert from . import _ffi_api +from .base import astext, pretty_print from .expr import Call @@ -68,10 +69,34 @@ def __call__(self, *args): return Call(self, args, None, None) def __str__(self): - from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel - return pretty_print(self) + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Returns + ------- + text : str + The text format of the expression. + + Notes + ----- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + """ + return astext(self, show_meta_data, annotate) + @tvm._ffi.register_func("relay.FunctionWithFields") def FunctionWithFields( diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 82bb698f2773..9283727ad41a 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -18,4 +18,3 @@ from .parser import ir, ir_module from .parser import parse as from_source from .parser import tir -from .printer import script diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index dc37ea1ff6a6..01d89dacbf52 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -20,4 +20,3 @@ in a roundtrippable way. """ from . import default -from .printer import script diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py deleted file mode 100644 index 2ce6329dca08..000000000000 --- a/python/tvm/script/printer/printer.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The printer interface""" -from typing import Optional - -from tvm.runtime.object_path import ObjectPath - -from . import _ffi_api - - -def script( - obj, - indent_space: int = 4, - print_line_number: bool = False, - num_context_lines: int = -1, - path_to_underline: Optional[ObjectPath] = None, -): - """Print a TVM IR as a TVMScript text format. - - Parameters - ---------- - obj : object - An TVM object representing TVM IR - indent_space : int = 4 - The number of spaces to indent - print_line_number : bool = False - Whether to print line number - num_context_lines : int = -1 - The number of context lines to print. -1 means all lines. - path_to_underline : Optional[ObjectPath] - The path to underline in the script. - - Returns - ------- - script : str - The TVMScript text format - """ - return _ffi_api.Script( # type: ignore # pylint: disable=no-member - obj, indent_space, print_line_number, num_context_lines, path_to_underline - ) diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 03d8a4920718..1a0e7aea39c9 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -90,7 +90,7 @@ impl GlobalVar { // TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) external! { - #[name("ir.AsText")] + #[name("relay.ir.AsText")] fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index bfd0a5917556..9a669493ccb7 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -587,7 +587,11 @@ TVM_REGISTER_GLOBAL("transform.OverrideInstruments") Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { - LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); + if (const auto* f = runtime::Registry::Get("relay.PrintIR")) { + (*f)(mod, header, show_meta_data); + } else { + LOG(INFO) << "PrintIR(" << header << "):\n" << mod; + } return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 53c680b722cd..ef21604d8a71 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -19,7 +19,7 @@ #include "annotated_region_set.h" -#include +#include #include #include diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index aca42397916c..443bd5ec1da3 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -27,9 +27,9 @@ #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ -#include #include #include +#include #include #include #include diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index 65b8516cb16c..f7a5e7bf2d12 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -31,9 +31,9 @@ * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ -#include #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 2a90b911b676..05d5b36e3614 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -27,8 +27,8 @@ * code correctness, since hitting an unmatched case results in a * dynamic error unless exhaustiveness is checked in advance. */ -#include #include +#include #include #include diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 3bde1a1e3746..7940e347b3ea 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ -#include #include +#include #include #include diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index afa17750d8a8..a622f96c81da 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -24,9 +24,9 @@ * Codegen. */ -#include #include #include +#include #include #include #include diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc index 42add45b013c..6c825a18901a 100644 --- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc +++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc @@ -17,9 +17,9 @@ * under the License. */ -#include #include #include +#include #include #include #include diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 571a56ad97c0..a0e0ac772fb0 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include +#include #include #include #include diff --git a/src/relay/backend/contrib/uma/relay_to_tir.cc b/src/relay/backend/contrib/uma/relay_to_tir.cc index 8aed69453158..ca3ae0ebec6b 100644 --- a/src/relay/backend/contrib/uma/relay_to_tir.cc +++ b/src/relay/backend/contrib/uma/relay_to_tir.cc @@ -23,9 +23,9 @@ * \brief this file contains the target hooks for the Universal Modular Accelerator Interface (UMA). */ -#include #include #include +#include #include #include #include diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9ba90b9f676d..fb23c4cc082a 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -25,11 +25,11 @@ #include "compiler.h" #include -#include #include #include #include #include +#include #include #include #include diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 163ec399013b..9160ce0e2e42 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_ #define TVM_RELAY_BACKEND_VM_COMPILER_H_ -#include +#include #include #include #include diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h index 19e7f3ccebfb..ca68c9b086b0 100644 --- a/src/relay/collage/partition_rule.h +++ b/src/relay/collage/partition_rule.h @@ -31,7 +31,7 @@ #include #include -#include "../../printer/doc.h" +#include "../printer/doc.h" #include "./candidate_partition.h" #include "./combiner_rule.h" #include "./sub_graph.h" diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5f7b8747a751..5f913026080d 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -51,5 +51,10 @@ TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span } }); +TVM_REGISTER_GLOBAL("relay.PrintIR") + .set_body_typed([](ObjectRef mod, String header, bool show_metadata) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); + }); + } // namespace relay } // namespace tvm diff --git a/src/ir/error.cc b/src/relay/ir/error.cc similarity index 97% rename from src/ir/error.cc rename to src/relay/ir/error.cc index 26448d04005c..940efd91aa52 100644 --- a/src/ir/error.cc +++ b/src/relay/ir/error.cc @@ -16,13 +16,9 @@ * specific language governing permissions and limitations * under the License. */ - -/*! - * \file ir/error.cc - * \brief Utilities for error tracking and reporting. - */ -#include #include +#include +#include // clang-format off #include @@ -31,6 +27,7 @@ // clang-format on namespace tvm { +namespace relay { template using NodeMap = std::unordered_map; @@ -137,5 +134,5 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, } this->node_to_gv_.insert({node, global}); } - +} // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index c41eb0f8ad99..5c5cd6f4b721 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -23,8 +23,8 @@ */ #include "transform.h" -#include #include +#include #include #include #include diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 3c638a59f46e..6c88aec8b957 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ -#include #include +#include #include #include diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 6d6d5f70c0c2..740766172ddc 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ #define TVM_RELAY_OP_TYPE_RELATIONS_H_ -#include +#include #include #include diff --git a/src/printer/doc.cc b/src/relay/printer/doc.cc similarity index 98% rename from src/printer/doc.cc rename to src/relay/printer/doc.cc index b06995fb1286..79313c9a587f 100644 --- a/src/printer/doc.cc +++ b/src/relay/printer/doc.cc @@ -30,9 +30,10 @@ #include #include -#include "../support/str_escape.h" +#include "../../support/str_escape.h" namespace tvm { +namespace relay { /*! * \brief Represent a piece of text in the doc. @@ -157,4 +158,5 @@ Doc Doc::Concat(const std::vector& vec, const Doc& sep) { } return seq; } +} // namespace relay } // namespace tvm diff --git a/src/printer/doc.h b/src/relay/printer/doc.h similarity index 97% rename from src/printer/doc.h rename to src/relay/printer/doc.h index dc6ba8952f3e..36f26d9bd24b 100644 --- a/src/printer/doc.h +++ b/src/relay/printer/doc.h @@ -23,8 +23,8 @@ * * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 */ -#ifndef TVM_PRINTER_DOC_H_ -#define TVM_PRINTER_DOC_H_ +#ifndef TVM_RELAY_PRINTER_DOC_H_ +#define TVM_RELAY_PRINTER_DOC_H_ #include #include @@ -35,6 +35,7 @@ #include namespace tvm { +namespace relay { /*! * \brief Doc atom node for the ADT. @@ -162,6 +163,6 @@ class Doc { /*! \brief Internal doc stream. */ std::vector stream_; }; - +} // namespace relay } // namespace tvm -#endif // TVM_PRINTER_DOC_H_ +#endif // TVM_RELAY_PRINTER_DOC_H_ diff --git a/src/printer/meta_data.h b/src/relay/printer/meta_data.h similarity index 95% rename from src/printer/meta_data.h rename to src/relay/printer/meta_data.h index ddf0d78087ee..2dfd594de7eb 100644 --- a/src/printer/meta_data.h +++ b/src/relay/printer/meta_data.h @@ -16,13 +16,8 @@ * specific language governing permissions and limitations * under the License. */ - -/*! - * \file tvm/printer/meta_data.h - * \brief Meta data context for printers. - */ -#ifndef TVM_PRINTER_META_DATA_H_ -#define TVM_PRINTER_META_DATA_H_ +#ifndef TVM_RELAY_PRINTER_META_DATA_H_ +#define TVM_RELAY_PRINTER_META_DATA_H_ #include @@ -32,6 +27,7 @@ #include "doc.h" namespace tvm { +namespace relay { /*! * \brief Meta data context for Printers * @@ -140,5 +136,6 @@ class TextMetaDataContext { /*! \brief map from meta data into its string representation */ std::unordered_map meta_repr_; }; +} // namespace relay } // namespace tvm -#endif // TVM_PRINTER_META_DATA_H_ +#endif // TVM_RELAY_PRINTER_META_DATA_H_ diff --git a/src/printer/model_library_format_printer.cc b/src/relay/printer/model_library_format_printer.cc similarity index 96% rename from src/printer/model_library_format_printer.cc rename to src/relay/printer/model_library_format_printer.cc index 4220aa00f5a4..76d0f1423d4f 100644 --- a/src/printer/model_library_format_printer.cc +++ b/src/relay/printer/model_library_format_printer.cc @@ -26,7 +26,7 @@ #include "text_printer.h" namespace tvm { -namespace printer { +namespace relay { class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { public: @@ -69,7 +69,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { TextPrinter text_printer_; }; -TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter") +TVM_REGISTER_GLOBAL("relay.ir.ModelLibraryFormatPrinter") .set_body_typed([](bool show_meta_data, const runtime::TypedPackedFunc& annotate, bool show_warning) { @@ -77,5 +77,5 @@ TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter") make_object(show_meta_data, annotate, show_warning)); }); -} // namespace printer +} // namespace relay } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc similarity index 99% rename from src/printer/relay_text_printer.cc rename to src/relay/printer/relay_text_printer.cc index 76cac28b07f7..cc86f9b56435 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/relay/printer/relay_text_printer.cc @@ -40,10 +40,10 @@ #include #include -#include "../ir/attr_functor.h" -#include "../parser/meta_ref.h" -#include "../relay/analysis/dependency_graph.h" -#include "../support/scalars.h" +#include "../../ir/attr_functor.h" +#include "../../parser/meta_ref.h" +#include "../../support/scalars.h" +#include "../analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -970,10 +970,5 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) { return doc; } -TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { - auto text = AsText(node, false, nullptr); - return text; -}); - } // namespace relay } // namespace tvm diff --git a/src/printer/text_printer.cc b/src/relay/printer/text_printer.cc similarity index 95% rename from src/printer/text_printer.cc rename to src/relay/printer/text_printer.cc index 4d4113fef694..f51f7c3dfa57 100644 --- a/src/printer/text_printer.cc +++ b/src/relay/printer/text_printer.cc @@ -23,7 +23,7 @@ * that can be parsed by a parser. */ -#include "text_printer.h" +#include "./text_printer.h" #include @@ -31,6 +31,7 @@ #include namespace tvm { +namespace relay { static const char* kSemVer = "0.0.5"; @@ -124,8 +125,8 @@ String AsText(const ObjectRef& node, bool show_meta_data, return doc.str(); } -TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint); - -TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText); +TVM_REGISTER_GLOBAL("relay.ir.PrettyPrint").set_body_typed(PrettyPrint); +TVM_REGISTER_GLOBAL("relay.ir.AsText").set_body_typed(AsText); +} // namespace relay } // namespace tvm diff --git a/src/printer/text_printer.h b/src/relay/printer/text_printer.h similarity index 95% rename from src/printer/text_printer.h rename to src/relay/printer/text_printer.h index 925c2ebf494e..707bbec5ad33 100644 --- a/src/printer/text_printer.h +++ b/src/relay/printer/text_printer.h @@ -23,8 +23,8 @@ * that can be parsed by a parser. */ -#ifndef TVM_PRINTER_TEXT_PRINTER_H_ -#define TVM_PRINTER_TEXT_PRINTER_H_ +#ifndef TVM_RELAY_PRINTER_TEXT_PRINTER_H_ +#define TVM_RELAY_PRINTER_TEXT_PRINTER_H_ #include #include @@ -41,19 +41,16 @@ #include #include -#include "../ir/attr_functor.h" -#include "../relay/analysis/dependency_graph.h" +#include "../../ir/attr_functor.h" +#include "../analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" -#include "text_printer.h" - -namespace tvm { -class TextPrinter; -} // namespace tvm namespace tvm { namespace relay { +class TextPrinter; + class RelayTextPrinter : public ExprFunctor, public PatternFunctor, public TypeFunctor, @@ -227,14 +224,10 @@ class RelayTextPrinter : public ExprFunctor, DependencyGraph dg_; class AttrPrinter; friend class AttrPrinter; - friend class tvm::TextPrinter; + friend class tvm::relay::TextPrinter; }; -} // namespace relay -} // namespace tvm - -namespace tvm { -namespace tir { +using namespace ::tvm::tir; /*! * \brief Meta node collector @@ -274,7 +267,7 @@ class MetaCollector : public StmtExprVisitor { }; class TIRTextPrinter : public StmtFunctor, - public ExprFunctor, + public tir::ExprFunctor, public TypeFunctor { public: explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) @@ -298,7 +291,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitExpr_(const FloatImmNode* op) override; Doc VisitExpr_(const StringImmNode* op) override; Doc VisitExpr_(const CastNode* op) override; - Doc VisitExpr_(const VarNode* op) override; + Doc VisitExpr_(const tir::VarNode* op) override; Doc VisitExpr_(const AddNode* op) override; Doc VisitExpr_(const SubNode* op) override; Doc VisitExpr_(const MulNode* op) override; @@ -323,8 +316,8 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitExpr_(const LoadNode* op) override; Doc VisitExpr_(const RampNode* op) override; Doc VisitExpr_(const BroadcastNode* op) override; - Doc VisitExpr_(const LetNode* op) override; - Doc VisitExpr_(const CallNode* op) override; + Doc VisitExpr_(const tir::LetNode* op) override; + Doc VisitExpr_(const tir::CallNode* op) override; Doc VisitExpr_(const ShuffleNode* op) override; Doc VisitExpr_(const ReduceNode* op) override; Doc VisitExprDefault_(const Object* op) override; @@ -357,7 +350,7 @@ class TIRTextPrinter : public StmtFunctor, /*! \brief meta collector */ MetaCollector meta_collector_; /*! \brief Map from Var to Doc */ - std::unordered_map memo_var_; + std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map memo_buf_; /*! \brief Map from Buffer to Doc */ @@ -365,7 +358,7 @@ class TIRTextPrinter : public StmtFunctor, /*! \brief name allocation map */ std::unordered_map name_alloc_map_; - friend class tvm::TextPrinter; + friend class TextPrinter; Doc VisitType_(const PrimTypeNode* node) override; Doc VisitType_(const PointerTypeNode* node) override; @@ -396,7 +389,7 @@ class TIRTextPrinter : public StmtFunctor, template static Doc PrintConstScalar(DataType dtype, const T& data); Doc GetUniqueName(std::string prefix); - Doc AllocVar(const Var& var); + Doc AllocVar(const tir::Var& var); Doc AllocConst(const AllocateConst& var); Doc AllocBuf(const Buffer& buffer); Doc AllocProducer(const DataProducer& buffer); @@ -412,11 +405,6 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate); -} // namespace tir -} // namespace tvm - -namespace tvm { - class TextPrinter { public: explicit TextPrinter(bool show_meta_data, @@ -441,7 +429,7 @@ class TextPrinter { /*! \brief Relay Text Printer */ relay::RelayTextPrinter relay_text_printer_; /*! \brief TIR Text Printer */ - tir::TIRTextPrinter tir_text_printer_; + TIRTextPrinter tir_text_printer_; bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } @@ -472,6 +460,7 @@ class TextPrinter { Doc PrintMod(const IRModule& mod); }; +} // namespace relay } // namespace tvm -#endif // TVM_PRINTER_TEXT_PRINTER_H_ +#endif // TVM_RELAY_PRINTER_TEXT_PRINTER_H_ diff --git a/src/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc similarity index 97% rename from src/printer/tir_text_printer.cc rename to src/relay/printer/tir_text_printer.cc index 4d74cc6d5a48..eb089bd0d7ed 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/relay/printer/tir_text_printer.cc @@ -36,13 +36,13 @@ #include #include -#include "../tir/transforms/ir_utils.h" +#include "../../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" namespace tvm { -namespace tir { +namespace relay { Doc TIRTextPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("(nullptr)"); @@ -93,9 +93,9 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { memo_buf_.clear(); // ordered vars associated with buffers, for consistent printing - std::vector buffer_vars_ordered; + std::vector buffer_vars_ordered; - for (Var v : op->params) { + for (tir::Var v : op->params) { auto buffer_map_find = op->buffer_map.find(v); if (buffer_map_find != op->buffer_map.end()) { auto map_data = *buffer_map_find; @@ -132,7 +132,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { if (memo_buf_.size() != 0) { Doc buffer_doc; std::vector buffer_docs; - for (const Var& v : buffer_vars_ordered) { + for (const tir::Var& v : buffer_vars_ordered) { const Buffer buf = op->buffer_map[v]; buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf))); } @@ -144,7 +144,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { if (op->buffer_map.size() != 0) { // print buffer_map std::vector buffer_map_doc; - for (const Var& v : buffer_vars_ordered) { + for (const tir::Var& v : buffer_vars_ordered) { const Buffer buf = op->buffer_map[v]; buffer_map_doc.push_back(Print(v) << ": " << Print(buf)); } @@ -302,9 +302,9 @@ Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { return doc; } -Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { - const Var& var = GetRef(op); - return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef(op)); +Doc TIRTextPrinter::VisitExpr_(const tir::VarNode* op) { + const tir::Var& var = GetRef(op); + return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef(op)); } #define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \ @@ -401,13 +401,13 @@ Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { return doc; } -Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { +Doc TIRTextPrinter::VisitExpr_(const tir::LetNode* op) { Doc doc; doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body); return doc; } -Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { +Doc TIRTextPrinter::VisitExpr_(const tir::CallNode* op) { Doc doc; std::vector func_args; if (auto* ptr_op = op->op.as()) { @@ -771,7 +771,7 @@ Doc TIRTextPrinter::GetUniqueName(std::string prefix) { return Doc::Text(unique_prefix); } -Doc TIRTextPrinter::AllocVar(const Var& var) { +Doc TIRTextPrinter::AllocVar(const tir::Var& var) { const auto& it = memo_var_.find(var); if (it != memo_var_.end()) { return it->second; @@ -831,7 +831,7 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { return doc; } -bool TIRTextPrinter::GetVarName(Var v, std::string* s) { +bool TIRTextPrinter::GetVarName(tir::Var v, std::string* s) { auto it = memo_var_.find(v); if (it == memo_var_.end()) { return false; @@ -841,5 +841,5 @@ bool TIRTextPrinter::GetVarName(Var v, std::string* s) { return true; } -} // namespace tir +} // namespace relay } // namespace tvm diff --git a/src/printer/tir_text_printer_debug.cc b/src/relay/printer/tir_text_printer_debug.cc similarity index 98% rename from src/printer/tir_text_printer_debug.cc rename to src/relay/printer/tir_text_printer_debug.cc index 6c29558f722c..914d8877d2f7 100644 --- a/src/printer/tir_text_printer_debug.cc +++ b/src/relay/printer/tir_text_printer_debug.cc @@ -29,7 +29,7 @@ #include namespace tvm { -namespace tir { +namespace relay { std::optional span_text(const Span& span) { if (!span.defined()) { @@ -93,5 +93,5 @@ Doc TIRTextPrinterDebug::VisitExpr(const PrimExpr& e) { return TIRTextPrinter::VisitExpr(e); } -} // namespace tir +} // namespace relay } // namespace tvm diff --git a/src/printer/tir_text_printer_debug.h b/src/relay/printer/tir_text_printer_debug.h similarity index 90% rename from src/printer/tir_text_printer_debug.h rename to src/relay/printer/tir_text_printer_debug.h index d0046034cfbf..f7cb7a6554ec 100644 --- a/src/printer/tir_text_printer_debug.h +++ b/src/relay/printer/tir_text_printer_debug.h @@ -23,8 +23,8 @@ * that can be parsed by a parser. */ -#ifndef TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ -#define TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ +#ifndef TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ +#define TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ #include #include @@ -32,7 +32,7 @@ #include "text_printer.h" namespace tvm { -namespace tir { +namespace relay { class TIRTextPrinterDebug : public TIRTextPrinter { public: @@ -64,7 +64,7 @@ class TIRTextPrinterDebug : public TIRTextPrinter { std::vector> exprs_by_line_; }; -} // namespace tir +} // namespace relay } // namespace tvm -#endif // TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ +#endif // TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ diff --git a/src/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc similarity index 96% rename from src/printer/tvmscript_printer.cc rename to src/relay/printer/tvmscript_printer.cc index c578bc53d3d3..096611095097 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/relay/printer/tvmscript_printer.cc @@ -39,13 +39,15 @@ #include #include -#include "../tir/transforms/ir_utils.h" +#include "../../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" namespace tvm { -namespace tir { +namespace relay { + +using namespace tvm::tir; enum class ExprPrecedence : int { /*! \brief Identity(e.g., IntImm, Var) and function call(e.g., floordiv, min) */ @@ -77,14 +79,14 @@ enum class ExprPrecedence : int { */ class BufferUsageFinder : public StmtExprVisitor { public: - static Map> FindUsage(Map> usage, Stmt body) { + static Map> FindUsage(Map> usage, Stmt body) { BufferUsageFinder visitor(std::move(usage)); visitor.VisitStmt(body); return std::move(visitor.usage_); } - void VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + void VisitExpr_(const tir::VarNode* op) final { + tir::Var var = GetRef(op); if (!usage_.count(var)) { usage_.Set(var, {}); } @@ -107,7 +109,7 @@ class BufferUsageFinder : public StmtExprVisitor { } private: - explicit BufferUsageFinder(Map> usage) : usage_(usage) {} + explicit BufferUsageFinder(Map> usage) : usage_(usage) {} void VisitBuffer(const Buffer& buffer) { if (buffers_visited_.count(buffer.get())) { @@ -124,7 +126,7 @@ class BufferUsageFinder : public StmtExprVisitor { } // The search result. - Map> usage_; + Map> usage_; // The buffers that have been visited so far, to avoid duplicate // entries in the search result. std::unordered_set buffers_visited_; @@ -139,7 +141,7 @@ class BufferUsageFinder : public StmtExprVisitor { * subexpression to decide whether or not parentheses is needed. */ class TVMScriptPrinter : public StmtFunctor, - public ExprFunctor, + public tir::ExprFunctor, public TypeFunctor { public: explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta, @@ -167,20 +169,20 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief meta data context */ TextMetaDataContext meta_; /*! \brief meta collector */ - MetaCollector meta_collector_; + relay::MetaCollector meta_collector_; /*! \brief map from Function to GlobalVar */ std::unordered_map func2var_; /*! \brief var collector (var defined by For/Loop/Block) */ - std::unordered_set var_not_in_headers_; + std::unordered_set var_not_in_headers_; /*! * \brief buffer collector * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion) */ std::unordered_set buf_not_in_headers_; /*! \brief Map from Var to thread env name */ - std::unordered_map var_env_map_; + std::unordered_map var_env_map_; /*! \brief Map from Var to Doc */ - std::unordered_map memo_var_; + std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ @@ -194,7 +196,7 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief loop stack without annotations */ std::vector simple_loop_stack_; /*! \brief the maps from loop_vars to the loops */ - std::unordered_map loop_var_map_; + std::unordered_map loop_var_map_; /*! * \brief simple block vars remap from loop vars * simple_remap requires: @@ -210,12 +212,12 @@ class TVMScriptPrinter : public StmtFunctor, * LetStmt or Allocate that generates their data pointer, rather * than in the header. */ - Map> buffer_var_usage_; + Map> buffer_var_usage_; /*! \brief Analyzer to simplify some expressions. */ arith::Analyzer ana_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; - Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override; @@ -243,8 +245,8 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override; - Doc VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) override; - Doc VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override; Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override; @@ -297,9 +299,9 @@ class TVMScriptPrinter : public StmtFunctor, static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc GetUniqueName(std::string prefix); - Doc AllocVar(const Var& var); + Doc AllocVar(const tir::Var& var); Doc AllocBuf(const Buffer& buffer); - void TryDeallocVar(const Var& var); + void TryDeallocVar(const tir::Var& var); bool ContainsOptionalInfo(const Stmt& stmt); /*! * \brief Check if a buffer declaration satisfies: @@ -338,7 +340,9 @@ class TVMScriptPrinter : public StmtFunctor, * \return A boolean indicating whether the input loop depends on previous loops */ bool DependOnPrevLoops(const ForNode* for_op) { - auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); }; + auto f_check = [&var_map = this->loop_var_map_](const tir::VarNode* v) { + return var_map.count(v); + }; return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check); } @@ -494,7 +498,7 @@ Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { return Doc::Text(unique_prefix); } -Doc TVMScriptPrinter::AllocVar(const Var& var) { +Doc TVMScriptPrinter::AllocVar(const tir::Var& var) { const auto& it = memo_var_.find(var); if (it != memo_var_.end()) { return it->second; @@ -522,8 +526,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { if (!buf->strides.empty()) { doc << ", strides=" << Print(buf->strides); } - if (buf->elem_offset->IsInstance()) { - Var elem_offset = Downcast(buf->elem_offset); + if (buf->elem_offset->IsInstance()) { + tir::Var elem_offset = Downcast(buf->elem_offset); if (memo_var_.find(elem_offset) != memo_var_.end()) { doc << ", elem_offset=" << Print(buf->elem_offset); } else { @@ -585,7 +589,7 @@ bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. */ -void TVMScriptPrinter::TryDeallocVar(const Var& var) { +void TVMScriptPrinter::TryDeallocVar(const tir::Var& var) { auto it = memo_var_.find(var); ICHECK(it != memo_var_.end()); std::string print_name = it->second.str(); @@ -695,7 +699,7 @@ Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { int n_var = static_cast(op->rhs.size()); doc << tir_prefix_ << ".comm_reducer(lambda "; - for (const Var& v_lhs : op->lhs) { + for (const tir::Var& v_lhs : op->lhs) { doc << Print(v_lhs) << ", "; } for (int i = 0; i < n_var; ++i) { @@ -789,10 +793,10 @@ Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precede return doc; } -Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) { +Doc TVMScriptPrinter::VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; - const Var& var = GetRef(op); - return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef(op)); + const tir::Var& var = GetRef(op); + return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef(op)); } bool WillPrintConstScalar(const PrimExpr& expr) { @@ -938,7 +942,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_pr return doc; } -Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) { +Doc TVMScriptPrinter::VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", " @@ -946,7 +950,7 @@ Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_preceden return doc; } -Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) { +Doc TVMScriptPrinter::VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; if (auto* ptr_op = op->op.as()) { @@ -1090,7 +1094,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { namespace { bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) { - const Var& buffer_var = allocate->buffer_var; + const tir::Var& buffer_var = allocate->buffer_var; const DeclBufferNode* decl_buffer = allocate->body.as(); if (!decl_buffer) { return false; @@ -1468,8 +1472,8 @@ Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, const PrimExpr& value) -> bool { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; - if (!value->IsInstance()) return false; - const Var& var = Downcast(value); + if (!value->IsInstance()) return false; + const tir::Var& var = Downcast(value); auto it = loop_var_map_.find(var.get()); return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && expr_equal(it->second->extent, iter_var->dom->extent); @@ -1763,7 +1767,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { } // print var declaration Doc header_var; - std::vector vars; + std::vector vars; for (const auto& it : memo_var_) { if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) { vars.push_back(it.first.get()); @@ -1777,20 +1781,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { } } if (!vars.empty()) { - std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) { - return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); + std::sort(vars.begin(), vars.end(), [&](const tir::VarNode* a, const tir::VarNode* b) { + return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - auto type = GetRef(var)->type_annotation; + auto type = GetRef(var)->type_annotation; if (auto* ptr_type = type.as()) { auto* prim_type = ptr_type->element_type.as(); ICHECK(prim_type); - header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ + header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ << ".buffer_var("; header_var << PrintDType(prim_type->dtype) << ", " << Doc::StrLiteral(ptr_type->storage_scope) << ")"; } else { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ << ".var("; + header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ + << ".var("; header_var << PrintDType(var->dtype) << ")"; } } @@ -2013,5 +2018,5 @@ String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); -} // namespace tir +} // namespace relay } // namespace tvm diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index d18c17e63ca1..d70c7480e9e5 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -30,9 +30,9 @@ * as external functions. */ -#include #include #include +#include #include #include #include diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index f6cdf6d1ca18..32ca2878fdc9 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -29,10 +29,10 @@ * external functions, and they will use the provided compiler for codegen. */ -#include #include #include #include +#include #include #include #include diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc index 9ebdcb1e99b3..878b380a3717 100644 --- a/src/script/printer/printer.cc +++ b/src/script/printer/printer.cc @@ -23,18 +23,11 @@ namespace tvm { namespace script { namespace printer { -String Script(ObjectRef obj, int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) { - return DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root()), indent_spaces, - print_line_numbers, num_context_lines, path_to_underline); -} - Default* Default::Instance() { static Default inst; return &inst; } -TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script); TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix") .set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; }); TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType") diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 55d751c3311e..1aae0202ac42 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include "../../printer/text_printer.h" #include "./utils.h" namespace tvm { @@ -52,10 +51,11 @@ String ScheduleError::RenderReport(const String& primitive) const { } return it->second; }); - + const auto* f = runtime::Registry::Get("script.AsTVMScriptWithDiagnostic"); + ICHECK(f != nullptr); os << "ScheduleError: An error occurred in the schedule primitive '" << primitive << "'.\n\nThe IR with diagnostic is:\n" - << AsTVMScriptWithDiagnostic(mod, "T", false, annotate); + << ((*f)(mod, "T", false, annotate).operator String()); // print error message os << "Error message: " << msg; diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index bc9002ee841f..c97070e1bf89 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -30,7 +30,7 @@ #include #include -#include "../../printer/tir_text_printer_debug.h" +#include "../../relay/printer/tir_text_printer_debug.h" namespace tvm { namespace tir { @@ -42,7 +42,7 @@ Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt) DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) { // Determine the line that each stmt/expr will be printed on - tvm::tir::TIRTextPrinterDebug printer(false); + tvm::relay::TIRTextPrinterDebug printer(false); // Fill in the stmts and exprs' line info auto result = printer.Print(stmt).str(); diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 5ea6d7e5de6a..08fa01f0b39b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Union + import numpy as np import pytest - import tvm -import tvm.testing -from tvm import relay import tvm.relay.testing +import tvm.testing from numpy import isclose -from typing import Union +from tvm import relay SEMVER = '#[version = "0.0.5"]\n' @@ -74,7 +74,7 @@ def graph_equal(lhs, rhs): def roundtrip_expr(expr): - text = tvm.relay.Expr.astext(expr, show_meta_data=False) + text = expr.astext() x = tvm.parser.parse_expr(text) assert_graph_equal(x, expr) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index bb9602279404..f40d9427490d 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -885,5 +885,4 @@ def max_pool_blocked_compute(height, width, channel): if __name__ == "__main__": - # tvm.testing.main() - test_cache_read_specify_consumer() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index d4ae84a556d7..2806c7b2fc52 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -343,7 +343,6 @@ def test_prim_func(): func = tvm.tir.PrimFunc([x, y, b], stmt) # make sure we can print - func.astext() assert func.buffer_map[func.params[2]].same_as(b) assert len(func.buffer_map) == 1 @@ -399,130 +398,5 @@ def test_intimm_cond(): assert x == 1 -def test_block_blockrealize(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - vx = tvm.tir.IterVar((16, 16), "vx", 0) - vx_var = vx.var - vy = tvm.tir.IterVar((16, 16), "vy", 2) - vy_var = vy.var - A = tvm.tir.decl_buffer((16), "float32") - B = tvm.tir.decl_buffer((16, 16), "float32") - alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32") - match_buffer = tvm.tir.decl_buffer((16, 16), "float32") - init_body = tvm.tir.BufferStore(A, 0.0, [vx_var]) - body = tvm.tir.BufferStore( - A, - tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]), - [vx_var], - ) - reads = [ - tvm.tir.BufferRegion( - B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)] - ) - ] - writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] - block_match_buffer = tvm.tir.MatchBufferRegion( - match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) - ) - - block = tvm.tir.Block( - [vx, vy], - reads, - writes, - "block", - body, - init=init_body, - alloc_buffers=[alloc_buffer], - match_buffers=[block_match_buffer], - annotations={"attr_key": "attr_value"}, - ) - - # Checking Block - assert isinstance(block, tvm.tir.Block) - # Checking iter_vars - assert block.iter_vars[0] == vx - assert block.iter_vars[1] == vy - # Checking reads/writes region - assert isinstance(block.reads[0], tvm.tir.BufferRegion) - assert block.reads[0].buffer == B - assert block.reads[0].region[0].min == vx_var - assert block.reads[0].region[1].min == vy_var - assert isinstance(block.writes[0], tvm.tir.BufferRegion) - assert block.writes[0].buffer == A - assert block.writes[0].region[0].min == vx_var - assert block.writes[0].region[0].extent == 1 - # Checking name_hint - assert block.name_hint == "block" - # Checking body - assert block.body == body - # Checking init - assert block.init == init_body - # Checking alloc_buffers - assert block.alloc_buffers[0] == alloc_buffer - # Checking match_buffers - assert block.match_buffers[0].buffer == match_buffer - assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion) - assert block.match_buffers[0].source.buffer == B - assert block.match_buffers[0].source.region[0].min == 0 - assert block.match_buffers[0].source.region[0].extent == 16 - - # Checking BlockRealize - block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block) - assert isinstance(block_realize, tvm.tir.BlockRealize) - assert block_realize.iter_values[0] == x - assert block_realize.iter_values[1] == y - assert block_realize.predicate == tvm.tir.const(True, "bool") - assert block_realize.block == block - - # make sure we can print using ReprPrinter - str(block) - str(block_realize) - # make sure we can print using TIRTextPrinter - func = tvm.tir.PrimFunc([], block_realize) - output = func.astext() - assert output.find("meta[tir.BlockRealise]") == -1 - assert output.find("bind") != -1 - assert output.find("reads") != -1 - assert output.find("writes") != -1 - assert output.find("alloc_buffer") != -1 - assert output.find("match_buffer") != -1 - assert output.find("attr") != -1 - assert output.find("with init()") != -1 - - -def test_tir_allocate(): - dtype = "int8" - storage_scope = "global" - ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - a = te.var("buffer", ptype) - allocate = tvm.tir.Allocate( - buffer_var=a, - dtype=dtype, - extents=[2, 2], - condition=tvm.get_global_func("tir.const_true")(dtype, None), - body=tvm.tir.Evaluate(2 + 1), - annotations={ - "attr1": "foo", - "attr2": "bar", - }, - ) - assert allocate.buffer_var == a - assert allocate.dtype == "int8" - assert list(allocate.extents) == [2, 2] - assert allocate.annotations["attr1"] == "foo" - assert allocate.annotations["attr2"] == "bar" - - # make sure we can print using TIRTextPrinter - func = tvm.tir.PrimFunc([], allocate) - output = func.astext() - assert ( - output.find( - 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})' - ) - != -1 - ) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 48af3ebaf529..d4abc26bb204 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np +import pytest import tvm +import tvm.testing from tvm import te from tvm.contrib.nvcc import have_fp16 -import numpy as np -import tvm.testing -import pytest - @tvm.testing.requires_cuda def test_lower_warp_memory_local_scope(): @@ -320,7 +319,7 @@ def test_lower_warp_memory_same_thread(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - assert "tvm_warp_shuffle" not in fdevice.astext() + assert "tvm_warp_shuffle" not in fdevice.script() @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py deleted file mode 100644 index 1bccb8188c9d..000000000000 --- a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest -import tvm.testing -from tvm.script.parser import tir as T -from tvm.script import script - - -def _test(obj, expected: str): - assert script(obj).strip() == expected.strip() - - -def test_remap(): - @T.prim_func - def block_with_remap_implicitly(): - for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): - v0 = T.axis.spatial(128, i0 + 1) - v1 = T.axis.spatial(128, i1) - v2 = T.axis.reduce(128, i2) - v3 = T.axis.spatial(128, i3 - 1) - v4 = T.axis.reduce(128, i4) - v5 = T.axis.spatial(128, i5) - pass - - @T.prim_func - def block_with_remap_explicitly(): - for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): - v0 = T.axis.spatial(128, i0 + 1) - v1, v2 = T.axis.remap("SR", [i1, i2]) - v3 = T.axis.spatial(128, i3 - 1) - v4, v5 = T.axis.remap("RS", [i4, i5]) - pass - - expected_output = """@T.prim_func -def main(): - with T.block("root"): - T.reads() - T.writes() - for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): - v0 = T.axis.spatial(128, i0 + 1) - v1, v2 = T.axis.remap("SR", [i1, i2]) - v3 = T.axis.spatial(128, i3 - 1) - v4, v5 = T.axis.remap("RS", [i4, i5]) - T.reads() - T.writes() - T.evaluate(0)""" - _test(block_with_remap_implicitly, expected_output) - _test(block_with_remap_explicitly, expected_output) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 9c15fbc88949..d62a1cd12c28 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -598,6 +598,47 @@ def test_tuple_type(): _assert_print(obj, "T.Tuple(T.float32, T.int32)") +def test_remap(): + from tvm.script import tir as T + + @T.prim_func + def block_with_remap_implicitly(): + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1 = T.axis.spatial(128, i1) + v2 = T.axis.reduce(128, i2) + v3 = T.axis.spatial(128, i3 - 1) + v4 = T.axis.reduce(128, i4) + v5 = T.axis.spatial(128, i5) + + @T.prim_func + def block_with_remap_explicitly(): + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1, v2 = T.axis.remap("SR", [i1, i2]) + v3 = T.axis.spatial(128, i3 - 1) + v4, v5 = T.axis.remap("RS", [i4, i5]) + + expected_output = """@T.prim_func +def main(): + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1, v2 = T.axis.remap("SR", [i1, i2]) + v3 = T.axis.spatial(128, i3 - 1) + v4, v5 = T.axis.remap("RS", [i4, i5]) + T.reads() + T.writes() + T.evaluate(0)""" + _assert_print(block_with_remap_explicitly, expected_output) + _assert_print(block_with_remap_implicitly, expected_output) + + if __name__ == "__main__": test_prim_func() test_block_realize() @@ -639,3 +680,4 @@ def test_tuple_type(): test_prim_type() test_pointer_type() test_tuple_type() + test_remap()