From d085deef345f84a207c2e8ebb25399c25d5218af Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 12:43:36 -0500 Subject: [PATCH 01/13] [TVMScript][Unittest] Validate round-trip for each TIR lowering step Prior to this PR, some of the IRModule transformations used during lowering can produce TIR that cannot be round-tripped through TVMScript. Since TVMScript is the default method for printing all TIR, this can make it difficult to identify which pass has introduced a breaking change. This PR adds a test that checks whether a module can correctly round-trip from TIR to TVMScript and back, for each lowering pass used in `tvm.build`. --- .../unittest/test_tvmscript_roundtrip.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index cd7f1726c9d9..e6aa89434857 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -22,6 +22,7 @@ import tvm.testing from tvm import tir from tvm.script import tir as T +from tvm.ir.instrument import pass_instrument import numpy as np @@ -181,6 +182,20 @@ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: return main +def copy_using_env_thread(): + shape = (64, 2, 4) + + @T.prim_func + def func(A: T.Buffer(shape), B: T.Buffer(shape)): + blocks, M, N = T.meta_var(shape) + + bx = T.launch_thread("blockIdx.x", blocks) + for i, j in T.grid(M, N): + B[bx, i, j] = A[bx, i, j] + + return func + + def opt_gemm_mod_host(): @tvm.script.ir_module class Module: @@ -3772,5 +3787,35 @@ def test_return_none_no_trailing_type(): assert "-> None" not in script +@pass_instrument +class ValidateTVMScriptRoundTrip: + def run_after_pass(self, mod, info): + after_roundtrip = tvm.script.from_source(mod.script(show_meta=True)) + tvm.ir.assert_structural_equal(mod, after_roundtrip, True) + + +@pytest.mark.parametrize( + "generator,target", + [ + (matmul, "llvm"), + pytest.param( + launch_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + pytest.param( + copy_using_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + ], +) +def test_roundtrip_all_lowering_steps(generator, target): + func = generator() + + with tvm.transform.PassContext(instruments=[ValidateTVMScriptRoundTrip()]): + tvm.build(func, target=target) + + if __name__ == "__main__": tvm.testing.main() From aa82b146e2fffc53b188ac8183ab5b0ea1c6e310 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 Feb 2023 22:31:47 +0800 Subject: [PATCH 02/13] [TVMScript] IRModule TVMScript Parser. This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. This commit contains the non-relax portions from https://github.com/apache/tvm/pull/13932. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu --- include/tvm/script/ir_builder/ir/frame.h | 11 +++-- include/tvm/script/ir_builder/ir/ir.h | 17 +++++++ python/tvm/script/ir_builder/base.py | 6 ++- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 ++++++++++++++++++ python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/entry.py | 1 + python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 ++++++++++++++------ python/tvm/script/parser/ir/parser.py | 4 ++ python/tvm/script/parser/tir/entry.py | 4 +- python/tvm/script/parser/tir/parser.py | 26 ++++++++++ src/script/ir_builder/ir/frame.cc | 12 +++-- src/script/ir_builder/ir/ir.cc | 32 ++++++++++++- src/script/ir_builder/ir/utils.h | 49 +++++++++++++++++++ src/script/ir_builder/tir/frame.cc | 15 ++++-- src/script/ir_builder/tir/utils.h | 2 +- 17 files changed, 246 insertions(+), 34 deletions(-) create mode 100644 src/script/ir_builder/ir/utils.h diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..dacfc361a6c7 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..b35bbd0a7df5 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..946be263a779 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..796d6f3aad04 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..9a2bfdfe486d 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -51,6 +51,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "tvm": tvm, } source = Source(program) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..105164ed5ffc 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -388,6 +403,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..13b3e298590f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..63171f672289 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..addf12928435 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..5764e90c8dd4 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,40 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1fe..f3b547532cfd 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " From 3c70c2ea776930aa5bb58e97e6cf5ea60626f360 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sun, 26 Feb 2023 11:05:47 -0500 Subject: [PATCH 03/13] [TVMScript] Expose IRModule::attrs as I.module_attrs This is an upstreaming of the non-relax portions of https://github.com/apache/tvm/pull/14132, including a unit test specically to validate `I.module_attrs`. --- include/tvm/script/ir_builder/base.h | 2 ++ include/tvm/script/ir_builder/ir/frame.h | 3 +++ python/tvm/ir/module.py | 14 ++++++++++++-- python/tvm/script/ir_builder/base.py | 11 +++++++++++ python/tvm/script/ir_builder/ir/__init__.py | 7 ++++++- python/tvm/script/ir_builder/ir/ir.py | 14 ++++++++++++++ python/tvm/script/parser/ir/__init__.py | 4 ++-- python/tvm/script/parser/ir/parser.py | 11 +++++++++-- src/ir/module.cc | 6 ++---- src/script/ir_builder/base.cc | 6 ++++++ src/script/ir_builder/ir/frame.cc | 3 ++- src/script/ir_builder/ir/ir.cc | 12 ++++++++++++ src/script/printer/ir/ir.cc | 5 +++++ tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++ 14 files changed, 100 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index dacfc361a6c7..ed425cf61441 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode { * \note Only defined functions are in the map, while declared functions are not included. */ Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..232c70aa93d8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index b35bbd0a7df5..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -138,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 946be263a779..b796de8113f3 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,9 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import decl_function, def_function, ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 796d6f3aad04..eabbd188d063 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,6 +16,10 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict, List + +from tvm.runtime import Object as tvm_Object + from tvm.ir import BaseFunc, GlobalVar from . import _ffi_api @@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None: The given function implementation """ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..f8c9d4f0afc9 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 13b3e298590f..201c99074f20 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) for stmt in node.body: if isinstance(stmt, doc.FunctionDef): self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..9d663f9a1a2f 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index addf12928435..92470ec65342 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 5764e90c8dd4..0c34f85246c9 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) { } } +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); } // namespace ir } // namespace ir_builder diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168ad..1c751d40f2e7 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index e6aa89434857..1e4006f0e5f3 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3707,6 +3707,19 @@ def func( return func +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in range(16): + B[i] = A[i] + + return Module + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3772,6 +3785,7 @@ def func( merge_shape_var_def, if_then_else_var, tvm_shfl_builtins, + ir_module_with_attrs, ) From f72b29c567fb8caeabf5962cdb489a6aa7acc512 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 31 Mar 2023 12:43:21 -0500 Subject: [PATCH 04/13] [TVMScript] Distinguish between void* and handle --- include/tvm/script/ir_builder/tir/ir.h | 4 ++-- python/tvm/script/ir_builder/tir/ir.py | 9 +++++++-- .../unittest/test_tvmscript_roundtrip.py | 19 +++++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index a0343b03955b..73219f0302d9 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -435,9 +435,9 @@ void Evaluate(PrimExpr value); */ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), // String storage_scope = "global", // - bool is_size_var = false) { + bool is_size_var = false, bool is_unknown_type = false) { Type type_annotation{nullptr}; - if (dtype.is_void() && storage_scope == "global") { + if (is_unknown_type && storage_scope == "global") { type_annotation = PrimType(runtime::DataType::Handle()); } else { type_annotation = PointerType(PrimType(dtype), storage_scope); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c3ced1e0338b..a10f4bc6746a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var: +def handle( + dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False +) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1460,7 +1462,10 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b res : PrimExpr The new tir.Var with type handle or casted expression with type handle. """ - return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + is_unknown_type = dtype is None + if dtype is None: + dtype = "void" + return _ffi_api.Handle(dtype, storage_scope, is_size_var, is_unknown_type) # type: ignore[attr-defined] # pylint: disable=no-member def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr: diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 1e4006f0e5f3..2ceedebb3c7a 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3347,6 +3347,25 @@ def func(): return func +def test_void_ptr_vs_handle(): + """Distinguish between void* and handle + + In the future, perhaps these should be de-duplicated by forbidding + one of the two C++ representations. + """ + # Generates PointerType(PrimType(DataType::Void())) + @T.prim_func + def void_ptr(out_ret_value: T.handle("void")): + T.evaluate(out_ret_value) + + # Generates PrimType(DataType::Handle()) + @T.prim_func + def handle(out_ret_value: T.handle): + T.evaluate(out_ret_value) + + assert not tvm.ir.structural_equal(void_ptr, handle) + + def void_ptr(): @T.prim_func def func(out_ret_value: T.handle("void")): From 41802a344cedbe864725895c0cc33f7ec51605fb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 11:04:30 -0500 Subject: [PATCH 05/13] [TIR] Use same DataType of builtin::tvm_struct_set in C++ and Python Prior to this commit, the python API `tvm.tir.op.tvm_struct_set` defined the return type of `builtin::tvm_struct_set` as `"handle"`, while the C++ API `tvm::tir::TVMStructSet` defined the return type as `DataType::Int(32)`. The data type used for this builtin has no effect, because no value is returned. However, this discrepancy can cause failure to roundtrip through TVMScript. This commit updates the Python API to use `"int32"`, for consistency with the C++ API and with `CodeGenCPU`. --- python/tvm/tir/op.py | 2 +- .../unittest/test_tvmscript_roundtrip.py | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0fe460c085d7..419ab2275858 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -527,7 +527,7 @@ def tvm_struct_set(arr, index, field, value): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_struct_set", arr, index, field, value) + return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value) def address_of(buffer_load, span=None): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ceedebb3c7a..0f44a85f35d9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm import tir -from tvm.script import tir as T +from tvm.script import tir as T, ir as I from tvm.ir.instrument import pass_instrument import numpy as np @@ -3739,6 +3739,39 @@ def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): return Module +def tvm_struct_set_generated_in_cpp(): + """Ensure same dtype for tvm_struct_set in Python/C++ + + The TVMStructSet method in C++, used internally by + LowerTVMBuiltin, and the Python method `T.tvm_struct_set`, used + when parsing TVMScript should use the same dtype "int32". + """ + + @I.ir_module + class Module: + @T.prim_func + def tir_packed_call(A: T.Buffer(16)): + T.attr(0, "device_id", 0) + T.attr(0, "device_type", 0) + T.evaluate( + T.tvm_call_cpacked( + "tvm_test_cpacked", + T.tvm_stack_make_array( + A.data, + T.tvm_stack_make_shape(16, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.Cast("float32", 0), + 0, + dtype="handle", + ), + dtype="int32", + ) + ) + + return tvm.tir.transform.LowerTVMBuiltin()(Module) + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3805,6 +3838,7 @@ def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): if_then_else_var, tvm_shfl_builtins, ir_module_with_attrs, + tvm_struct_set_generated_in_cpp, ) From 2d9c7bf23c1994ce8b73e5020b38dd56063be674 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 11:13:03 -0500 Subject: [PATCH 06/13] [TIR] Improved SeqStmt::Flatten utility Prior to this commit, `SeqStmt::Flatten` could accept an arbitrary number of arguments, where each argument was of type `const tir::Stmt&` or an iterable. However, if `SeqStmt::Flatten` were passed a subclass of `tir::Stmt`, the templated overload was selected as the better match. This commit rewrites `SeqStmt::Flatten` using C++17's `"constexpr if"` feature, to handle cases of `SeqStmt`, superclasses of `SeqStmt`, and other subclasses of `Stmt`. --- include/tvm/tir/stmt.h | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9ed9973871d9..51c99fb375f3 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -734,19 +734,33 @@ class SeqStmt : public Stmt { public: explicit Flattener(Array* seq) : seq_(seq) {} - void operator()(size_t i, const Stmt& stmt) const { - if (!stmt.defined()) return; - if (auto* op = stmt.as()) { - operator()(0, op->seq); - } else { - seq_->push_back(stmt); + template + void operator()(size_t i, const T& stmt_or_seq) const { + if constexpr (std::is_base_of_v) { + // Early bail-out, applicable to any ObjectRef + if (!stmt_or_seq.defined()) return; } - } - template - void operator()(size_t i, const T& seq) const { - for (auto v : seq) { - this->operator()(0, v); + if constexpr (std::is_same_v) { + // No need for dynamic type-checking if the static type is a + // SeqStmt. + (*this)(0, stmt_or_seq->seq); + } else if constexpr (std::is_base_of_v) { + // Dynamic type-checking for a SeqStmt that could be + // flattened. + if (auto* op = stmt_or_seq.template as()) { + operator()(0, op->seq); + } else { + seq_->push_back(stmt_or_seq); + } + } else if constexpr (std::is_base_of_v) { + // Any other Stmt type just gets appended. + seq_->push_back(stmt_or_seq); + } else { + // Anything else is treated as an iterable of Stmt. + for (auto v : stmt_or_seq) { + this->operator()(0, v); + } } } From 3266f53594c34980f3bc48b9d948528244d74d7e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 12:24:57 -0500 Subject: [PATCH 07/13] [TIR] Use String instead of StringImm for AttrStmtNode::node `tir::StringImm` can round-trip through TVMScript when used in a context that requires a PrimExpr, such as the arguments of a `tir::Call`. However, contexts that only require a `ObjectRef`, such as the `AttrStmtNode::node`, use the same TVMScript representation as `"string_value"`, but are parsed `tvm::String` instances. This commit updates `MakePackedAPI` to use `String` instead of `StringImm` in its default value for `AttrStmtNode::node`. --- src/tir/transforms/make_packed_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b12d3dd49f05..ed3a2da19613 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -266,7 +266,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - PrimExpr node = StringImm("default"); + ObjectRef node = String("default"); seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop)); seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop)); From 5fe1fee2281a1c8212dca0dcb012a1e8f5e9d8b1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 12:29:50 -0500 Subject: [PATCH 08/13] [TIR] Flatten SeqStmt on constructions Previously, SeqStmt could be nested, making a distinction between the nested `SeqStmt({SeqStmt({a,b}), c})` and the flat `SeqStmt({a,b,c})`, even though the two are semantically equivalent. This also caused an issue with round-trips through TVMScript, which does not preserve this distinction. This commit updates the `SeqStmt` constructor and the `SeqStmt` visitor in `StmtMutator` to flatten nested sequential statements provided. --- src/tir/ir/stmt.cc | 12 +++++++++ src/tir/ir/stmt_functor.cc | 4 +-- .../unittest/test_tvmscript_roundtrip.py | 25 +++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e4569898f7ed..3b07bd943973 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -387,6 +387,18 @@ TVM_REGISTER_NODE_TYPE(PrefetchNode); // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { + bool requires_flattening = std::any_of( + seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance(); }); + + if (requires_flattening) { + auto flattened = SeqStmt::Flatten(seq); + if (auto* ptr = flattened.as()) { + seq = ptr->seq; + } else { + seq = {flattened}; + } + } + auto node = make_object(); node->seq = std::move(seq); node->span = std::move(span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f5063b222b9b..7c693b7efcf7 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -439,9 +439,7 @@ Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { if (seq.same_as(op->seq)) { return GetRef(op); } else { - auto n = CopyOnWrite(op); - n->seq = std::move(seq); - return Stmt(n); + return SeqStmt(seq); } } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0f44a85f35d9..c02fcb852342 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3772,6 +3772,30 @@ def tir_packed_call(A: T.Buffer(16)): return tvm.tir.transform.LowerTVMBuiltin()(Module) +def nested_seqstmt(): + """Nested SeqStmt should be normalized to flat SeqStmt + + Nested SeqStmt are representable in the TIR structures, but are + flattened when converted to TVMScript. Previously, this could + cause failures to round-trip through TVMScript, including + erroneous use of TVMScript's concise-scoping rules. This was + resolved by normalizing nested SeqStmt in TIR, such that the use + of `tir.SeqStmt` below results in a single flat `tir.SeqStmt` + containing the three `tir.Evaluate` calls. + """ + func = tvm.tir.PrimFunc( + params=[], + body=tvm.tir.SeqStmt( + [ + tvm.tir.SeqStmt([tvm.tir.Evaluate(0), tvm.tir.Evaluate(1)]), + tvm.tir.Evaluate(2), + ] + ), + ) + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3839,6 +3863,7 @@ def tir_packed_call(A: T.Buffer(16)): tvm_shfl_builtins, ir_module_with_attrs, tvm_struct_set_generated_in_cpp, + nested_seqstmt, ) From b835a9ef07723c3210ee0e36538c699dbe467b02 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 12:35:20 -0500 Subject: [PATCH 09/13] [Node] Utility methods for ObjectPathPair handling This commit adds a templated overload to `SEqualReducer::operator()` that accepts a lambda function to update the path of the LHS and RHS of the comparison. ```c++ // Usage prior to this utility function if (equal.IsPathTracingEnabled()) { const ObjectPathPair& self_paths = equal.GetCurrentObjectPaths(); ObjectPathPair attr_paths = {self_paths->lhs_path->Attr("value"), self_paths->rhs_path->Attr("value")}; if (!equal(kv.second, other->Lookup(kv.first->name_hint), attr_paths)) return false; } else { if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; } // Usage after this utility function if (!equal(kv.second, other->Lookup(kv.first->name_hint), [](const auto& path) { return path->Attr("value"); })) { return false; } ``` --- include/tvm/node/structural_equal.h | 44 +++++++++++++++----- src/ir/module.cc | 64 ++++++++++++++--------------- src/node/structural_equal.cc | 51 ++++++++++++++++------- 3 files changed, 101 insertions(+), 58 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 5bd76404a998..cff1e775072f 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -195,20 +195,40 @@ class SEqualReducer { * \param rhs The right operand. * \return the immediate check result. */ - bool operator()(const double& lhs, const double& rhs) const; - bool operator()(const int64_t& lhs, const int64_t& rhs) const; - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; - bool operator()(const int& lhs, const int& rhs) const; - bool operator()(const bool& lhs, const bool& rhs) const; - bool operator()(const std::string& lhs, const std::string& rhs) const; - bool operator()(const DataType& lhs, const DataType& rhs) const; + bool operator()(const double& lhs, const double& rhs, + Optional paths = NullOpt) const; + bool operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const int& lhs, const int& rhs, Optional paths = NullOpt) const; + bool operator()(const bool& lhs, const bool& rhs, Optional paths = NullOpt) const; + bool operator()(const std::string& lhs, const std::string& rhs, + Optional paths = NullOpt) const; + bool operator()(const DataType& lhs, const DataType& rhs, + Optional paths = NullOpt) const; template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs) const { + bool operator()(const ENum& lhs, const ENum& rhs, + Optional paths = NullOpt) const { using Underlying = typename std::underlying_type::type; static_assert(std::is_same::value, "Enum must have `int` as the underlying type"); - return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs); + return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs, paths); + } + + template , ObjectPath>>> + bool operator()(const T& lhs, const T& rhs, const Callable& callable) { + if (IsPathTracingEnabled()) { + ObjectPathPair current_paths = GetCurrentObjectPaths(); + ObjectPathPair new_paths = {callable(current_paths->lhs_path), + callable(current_paths->rhs_path)}; + return (*this)(lhs, rhs, new_paths); + } else { + return (*this)(lhs, rhs); + } } /*! @@ -310,7 +330,8 @@ class SEqualReducer { void RecordMismatchPaths(const ObjectPathPair& paths) const; private: - bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; + bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, + Optional paths = NullOpt) const; bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const ObjectPathPair* paths) const; @@ -321,7 +342,8 @@ class SEqualReducer { template static bool CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data); + const PathTracingData* tracing_data, + Optional paths = NullOpt); /*! \brief Internal class pointer. */ Handler* handler_ = nullptr; diff --git a/src/ir/module.cc b/src/ir/module.cc index 9d663f9a1a2f..ba66a6689422 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (!equal(this->attrs, other->attrs)) return false; + if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) { + return false; + } + + if (equal.IsPathTracingEnabled()) { + if ((functions.size() != other->functions.size()) || + (type_definitions.size() != other->type_definitions.size())) { + return false; + } + } - if (functions.size() != other->functions.size()) return false; - // Update GlobalVar remap + // Define remaps for GlobalVar and GlobalTypeVar based on their + // string name. Early bail-out is only performed when path-tracing + // is disabled, as the later equality checks on the member variables + // will provide better error messages. for (const auto& gv : this->GetGlobalVars()) { - if (!other->ContainGlobalVar(gv->name_hint)) return false; - if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + if (other->ContainGlobalVar(gv->name_hint)) { + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; + } } - // Checking functions - for (const auto& kv : this->functions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("functions") - ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; - if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; - } else { - if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (other->ContainGlobalTypeVar(gtv->name_hint)) { + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; } } - if (type_definitions.size() != other->type_definitions.size()) return false; - // Update GlobalTypeVar remap - for (const auto& gtv : this->GetGlobalTypeVars()) { - if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; - if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + // Checking functions and type definitions + if (!equal(this->functions, other->functions, + [](const auto& path) { return path->Attr("functions"); })) { + return false; } - // Checking type_definitions - for (const auto& kv : this->type_definitions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair type_paths = { - obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("type_definitions") - ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false; - } else { - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; - } + if (!equal(this->type_definitions, other->type_definitions, + [](const auto& path) { return path->Attr("type_definitions"); })) { + return false; } + return true; } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 42726af9859a..66a347f6b8ba 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { template /* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data) { + const PathTracingData* tracing_data, + Optional paths) { if (BaseValueEqual()(lhs, rhs)) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); - return false; } + + if (tracing_data && !tracing_data->first_mismatch->defined()) { + if (paths) { + *tracing_data->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); + } + } + return false; } -bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { +bool SEqualReducer::operator()(const double& lhs, const double& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { +bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { +bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { +bool SEqualReducer::operator()(const int& lhs, const int& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { +bool SEqualReducer::operator()(const bool& lhs, const bool& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { +bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { +bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, - const void* rhs_address) const { + const void* rhs_address, Optional paths) const { if (lhs == rhs) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); - return false; } + + if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { + if (paths) { + *tracing_data_->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_); + } + } + + return false; } const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { From df93de1ac197167136575565b8ae6ac53bfdef65 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 14:29:06 -0500 Subject: [PATCH 10/13] [TIR] Use IRModuleNode::Remove to remove None in PrimFuncPass Prior to this commit, `PrimFuncPass` directly removed empty `PrimFunc` objects from the module's `Map functions`. Because it didn't update the `global_var_map_` as well, these two maps could become out of sync. Since the `global_var_map_` is checked as part of `StructuralEqual()`, but isn't displayed when printing to TVMScript, this can result in identical printouts being flagged as non-identical. This commit updates `PrimFuncPass` to call the `IRModuleNode::Remove` method, which updates both the `functions` and `global_var_map_` variables. --- src/tir/ir/transform.cc | 11 ++++--- .../unittest/test_tir_transform_helpers.py | 30 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 4c59a1767372..38ec16ea3755 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -88,7 +88,8 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { ICHECK(mod.defined()); - std::vector deleted_list; + std::vector deleted_list; + IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); // directly loop over the underlying dict @@ -101,14 +102,16 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) kv.second = std::move(func); if (!kv.second.defined()) { - deleted_list.push_back(kv.first); + deleted_list.push_back(Downcast(kv.first)); } } } - // automatic removal of None + // Automatic removal of None. This uses IRModuleNode::Remove + // instead of manipulating func_dict directly, to ensure that both + // the function map and the global_var_map_ are correctly updated. for (const auto& gv : deleted_list) { - func_dict->erase(gv); + mod_ptr->Remove(gv); } return mod; } diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py index f8dc0f682d06..657bda591ae2 100644 --- a/tests/python/unittest/test_tir_transform_helpers.py +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -17,7 +17,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tir as T, ir as I import tvm.testing @@ -119,5 +119,33 @@ def checker_filter_out_both(func: tvm.tir.PrimFunc): assert len(after.functions) == 0 +class TestFilterRemovesGlobalVarMap(tvm.testing.CompareBeforeAfter): + """Filtering out a function should be identical to never adding it + + This test is to guard against hidden state in the IRModule that + remains after filtering. Previously, this was observed in the + `IRModuleNode::global_var_map_`, which retained entries of + filtered-out functions. + """ + + transform = tvm.tir.transform.Filter(lambda prim_func: False) + + def before(self): + @I.ir_module + class module: + @T.prim_func + def func(): + T.evaluate(0) + + return module + + def expected(self): + @I.ir_module + class module: + pass + + return module + + if __name__ == "__main__": tvm.testing.main() From 904f4d4a9ebf3a1b319c0a7efd1729219d2f35c8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 16:18:05 -0500 Subject: [PATCH 11/13] [TIR] Merged kDeviceThreadAxis and kUseDynamicSharedMemoryTag Previously, `kDeviceThreadAxis` defined the IterVar to be used for each thread/block axis, and `kUseDynamicSharedMemoryTag` defined whether dynamic memory allocations exist, which are primarily used to produce a list of strings by `tvm::codegen::ExtractFuncInfo`. Because `kDeviceThreadAxis` is a `Array`, the IterVar is used prior to its definition site at `tir::attr::thread_extent`, which results in errors when attempting to round-trip through TVMScript. This commit replaces these attributes with `attr::kKernelLaunchParams`, which directly contains the kernel launch parameters. These are expressed as an `Array`, allowing the generated TVMScript to successfully round-trip. --- include/tvm/tir/function.h | 46 ++++++++++++++++++------- src/target/build_common.h | 12 ++----- src/target/source/codegen_metal.cc | 22 ++++-------- src/tir/transforms/remap_thread_axis.cc | 27 ++++++++------- src/tir/transforms/split_host_device.cc | 21 +++++++---- 5 files changed, 73 insertions(+), 55 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 48328263fb55..2cb1269d010f 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -272,10 +272,11 @@ PrimFunc Specialize(PrimFunc func, const Map& param_map); * \sa tvm::attr */ namespace attr { + /*! * \brief List of thread IterVar that a DeviceLaunch function corresponds to. * - * Type: Array + * Type: Array * * We call a device kernel launch function f using the following convention: * @@ -283,23 +284,42 @@ namespace attr { * [arg1, arg2, ..., arg_n, * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) * - * Here n = len(arg), m = len(work_size) = len(device_thread_axis). + * Here n = len(arg), m = len(work_size) = len(launch_params)-1. * - * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * The list of kernel launch params indicates which additional + * parameters will be provided to the PackedFunc by the calling + * scope. * - * The list of device_thread_axis indicates how can be bind the - * work_size arguments to the corresponding threads. + * - "threadIdx.x", "threadIdx.y", "threadIdx.z" * - * \sa tvm::CallingConv::kDeviceKernelLaunch - */ -constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; - -/*! - * \brief Whether or not use dynamic shared memory. + * The extent of the thread count in x/y/z, to be used when + * launching the compute kernel on the device. For example, the + * gridDimX/Y/Z parameters passed to cuLaunchKernel when launching a + * CUDA kernel, or the groupCountX/Y/Z parameters passed to + * vkCmdDispatch when dispatching a compute pipeline to Vulkan. * - * Type: Integer + * - "blockIdx.x", "blockIdx.y", "blockIdx.z" + * + * The extent of the block iterators, to be used when launching the + * compute kernel on the device. For example, the blockDimX/Y/Z + * parameters passed to cuLaunchKernel when launching a CUDA kernel. + * For runtimes that do not require the block to be provided + * externally, this parameter is ignored. For example, the + * spv::ExecutionModeLocalSize for SPIR-V shaders on Vulkan, where + * this parameter is defined in the shader. + * + * - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag + * + * The size of the shared memory that may be allocated internally by + * the kernel. For example, exposed as the + * CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in + * cuda. + * + * Defined as "tir.use_dyn_shared_memory". + * + * \sa tvm::CallingConv::kDeviceKernelLaunch */ -constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; +constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params"; /*! * \brief Whether to set noalias rule on the function arguments. diff --git a/src/target/build_common.h b/src/target/build_common.h index 35b3d92eb814..7c9ad8cb3c68 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -50,15 +50,9 @@ inline std::unordered_map ExtractFuncInfo(co for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { - auto thread_axis = opt.value(); - for (size_t i = 0; i < thread_axis.size(); ++i) { - info.launch_param_tags.push_back(thread_axis[i]->thread_tag); - } - } - if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { - if (opt.value().IntValue() != 0) { - info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + for (const auto& tag : opt.value()) { + info.launch_param_tags.push_back(tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 534e2c3654c4..36ef44bc4814 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -130,12 +130,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); - - for (IterVar iv : thread_axis) { - runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); - work_dim = std::max(work_dim, scope.dim_index + 1); + auto launch_params = f->GetAttr>(tir::attr::kKernelLaunchParams).value(); + for (const auto& tag : launch_params) { + if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { + runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); + work_dim = std::max(work_dim, scope.dim_index + 1); + } } + if (work_dim != 0) { // use ushort by default for now stream << " "; @@ -145,16 +147,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } - // bind thread axis - for (IterVar iv : thread_axis) { - ICHECK(!var_idmap_.count(iv->var.get())); - std::string vname = iv->thread_tag; - if (work_dim <= 1) { - vname = vname.substr(0, iv->thread_tag.length() - 2); - } - var_idmap_[iv->var.get()] = - CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); - } // the function scope. stream << ") {\n"; int func_scope = this->BeginScope(); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index e101e6b904ce..519a3e1f80d8 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -69,26 +69,29 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; } - auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; - auto thread_axis = opt_thread_axis.value(); - auto* n = f.CopyOnWrite(); - - // replace the thread axis - for (size_t i = 0; i < thread_axis.size(); ++i) { - auto it = tmap.find(thread_axis[i]->thread_tag); - if (it != tmap.end()) { - thread_axis.Set(i, it->second); + if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { + ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; + auto launch_params = opt.value(); + // replace the thread axis attribute + for (size_t i = 0; i < launch_params.size(); ++i) { + auto it = tmap.find(launch_params[i]->thread_tag); + if (it != tmap.end()) { + launch_params.Set(i, it->second); + } } + + func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params); } + + auto* n = func.CopyOnWrite(); n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body)); - return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); + return func; } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c43fc403ed94..3696ff84e5b8 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -51,6 +51,17 @@ class DeviceInfoCollector : public StmtVisitor { PrimExpr dyn_shmem_size_{0}; bool use_dyn_shmem_{false}; + Array GetLaunchParams() const { + Array output; + for (const auto& axis : thread_axis_) { + output.push_back(axis->thread_tag); + } + if (use_dyn_shmem_) { + output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + } + return output; + } + private: void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -199,8 +210,9 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); + device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams, + dev_info.GetLaunchParams()); + device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -208,10 +220,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (dev_info.use_dyn_shmem_) { - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); - } + (*device_mod_)->Add(kernel_symbol_global, device_func); // generate calls to the device function From de92607774f60ff270b04b07e36ab21a8e8b3a56 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Mar 2023 10:27:11 -0500 Subject: [PATCH 12/13] [TIR][Utils] Implemented ConvertSSA as IRModule transform When passes create new PrimFuncs, such as when `tir.SplitHostDevice` separates out a `tir::Stmt` into an independent function, the parameters of these new function may alias existing variable definitions. While this is well-defined, because variable definitions are not shared across function boundaries, it can give false discrepancies from `tvm.ir.assert_structural_equal`. This commit implements `tvm::tir::transform::ConvertSSA`, which ensures unique variable declaration locations across an entire module. --- include/tvm/tir/transform.h | 13 +++ src/tir/transforms/ir_utils.cc | 185 +++++++++++++++++++++++++++++---- 2 files changed, 178 insertions(+), 20 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d4f537ff3169..35aa392db2de 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect(); */ TVM_DLL Pass Simplify(); +/*! + * \brief Convert an IRModule to be SSA form. + * + * This pass handles cases where the same tir::Var appears in + * multiple functions within the same module. For example, after + * extracting a fragment from one function into another, where the + * same `tir::Var` may be defined both as within the body of the + * original function, and as a parameter within the hoisted function. + * + * \return The pass. + */ +TVM_DLL Pass ConvertSSA(); + /*! * \brief Instruments bound checkers. * diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index f6e4ac45c612..3d9f19c153e4 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -90,13 +91,106 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { class IRConvertSSA final : public StmtExprMutator { public: - PrimExpr VisitExpr_(const VarNode* op) final { - if (scope_.count(op) && !scope_[op].empty()) { - return scope_[op].back(); - } else { - return GetRef(op); + PrimFunc VisitPrimFunc(PrimFunc func) { + std::vector redefines; + + // Remap parameters, if they were used in another function + auto params = func->params.Map([&](const tir::Var& var) -> tir::Var { + if (defined_.count(var.get())) { + const ScopedRedefine& redefine = redefines.emplace_back(this, var); + return redefine.new_var; + } else { + defined_.insert(var.get()); + return var; + } + }); + + // Remap implicitly defined buffer parameters + { + std::unordered_set defined_params; + for (const auto& var : func->params) { + defined_params.insert(var.get()); + } + for (const auto& [var, buffer] : func->buffer_map) { + auto check_expr = [&](const PrimExpr& expr) { + auto* var_ptr = expr.as(); + if (!var_ptr) return; + if (defined_params.count(var_ptr)) return; + + if (defined_.count(var_ptr)) { + auto var = GetRef(var_ptr); + redefines.emplace_back(this, var); + } else { + defined_.insert(var_ptr); + } + }; + for (const auto& dim : buffer->shape) { + check_expr(dim); + } + for (const auto& stride : buffer->strides) { + check_expr(stride); + } + check_expr(buffer->elem_offset); + } + } + + // Update the buffer map, based on the redefined parameters + auto buffer_map = [&]() { + Map buffer_map; + bool made_change = false; + for (const auto& [var, buffer] : func->buffer_map) { + auto new_var = GetRemappedVar(var); + auto new_buf = GetRemappedBuffer(buffer); + + made_change = made_change || !var.same_as(new_var) || !buffer.same_as(new_buf); + buffer_map.Set(new_var, new_buf); + } + if (made_change) { + return buffer_map; + } else { + return func->buffer_map; + } + }(); + + auto attrs = [&]() -> DictAttrs { + Map dict; + bool made_change = false; + + for (const auto& [key, old_value] : func->attrs->dict) { + auto value = old_value; + if (auto* expr = value.as()) { + value = VisitExpr(GetRef(expr)); + } else if (auto* stmt = value.as()) { + value = VisitStmt(GetRef(stmt)); + } + + made_change = made_change || !value.same_as(old_value); + dict.Set(key, value); + } + + if (made_change) { + return DictAttrs(dict); + } else { + return func->attrs; + } + }(); + + auto body = VisitStmt(func->body); + + // If anything changed, update the returned function + if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) || + !attrs.same_as(func->attrs) || !body.same_as(func->body)) { + func = PrimFunc(params, body, func->ret_type, buffer_map, attrs); + } + + // Pop the redefines in reverse order of creation + while (redefines.size()) { + redefines.pop_back(); } + return func; } + + PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(GetRef(op)); } PrimExpr VisitExpr_(const LetNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { @@ -142,18 +236,27 @@ class IRConvertSSA final : public StmtExprMutator { return node; } + Var GetRemappedVar(Var var) { + if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) { + return it->second.back(); + } else { + return var; + } + } + Buffer GetRemappedBuffer(Buffer buf) { // Determine the buffer var that should be in the updated buffer, // given the current scope. If no redefines are present, then the // buffer var is unchanged. - Var new_buffer_var = buf->data; - auto var_it = scope_.find(buf->data.get()); - if (var_it != scope_.end() && !var_it->second.empty()) { - new_buffer_var = var_it->second.back(); - } + Var new_buffer_var = GetRemappedVar(buf->data); + PrimExpr elem_offset = VisitExpr(buf->elem_offset); + auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); }; + Array shape = buf->shape.Map(visit_expr); + Array strides = buf->strides.Map(visit_expr); // If no mapping is required, return the original buffer. - if (new_buffer_var.same_as(buf->data)) { + if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) && + shape.same_as(buf->shape) && strides.same_as(buf->strides)) { return buf; } @@ -169,9 +272,9 @@ class IRConvertSSA final : public StmtExprMutator { // new buffer, pushing it onto the scoped stack of existing // buffers. This will be popped when the new_buffer_var // redefinition is popped. - Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset, - buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, - buf->axis_separators, buf->span); + Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset, buf->name, + buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, + buf->span); buffers.push_back(new_buf); return new_buf; } @@ -239,16 +342,33 @@ class IRConvertSSA final : public StmtExprMutator { } ~ScopedRedefine() { - parent->scope_[old_var.get()].pop_back(); - for (auto& kv : parent->buf_remap_) { - std::vector& buffers = kv.second; - if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { - buffers.pop_back(); + if (parent) { + parent->scope_[old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } } } } - IRConvertSSA* parent; + ScopedRedefine& operator=(const ScopedRedefine&) = delete; + ScopedRedefine(const ScopedRedefine&) = delete; + + ScopedRedefine& operator=(ScopedRedefine&& other) { + swap(other); + return *this; + } + ScopedRedefine(ScopedRedefine&& other) { swap(other); } + + void swap(ScopedRedefine& other) { + std::swap(parent, other.parent); + std::swap(old_var, other.old_var); + std::swap(new_var, other.new_var); + } + + IRConvertSSA* parent{nullptr}; Var old_var; Var new_var; }; @@ -447,5 +567,30 @@ std::pair GetAsyncWaitAttributes(const AttrStmtNode* op) { return std::make_pair(op->value, inner->value); } +namespace transform { +Pass ConvertSSA() { + auto pass_func = [](IRModule mod, PassContext ctx) { + tir::IRConvertSSA converter; + Map functions; + bool made_change = false; + for (auto [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto updated = converter.VisitPrimFunc(GetRef(ptr)); + if (!updated.same_as(base_func)) { + made_change = true; + base_func = updated; + } + } + functions.Set(gvar, base_func); + } + if (made_change) { + mod.CopyOnWrite()->functions = std::move(functions); + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); +} + +} // namespace transform } // namespace tir } // namespace tvm From 0d81a4e6b927513a2e46d223fa2db22cddf89f99 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Mar 2023 10:31:07 -0500 Subject: [PATCH 13/13] [TIR] Update SplitHostDevice to post-process with ConvertSSA Avoid duplicate variable defitions between the host and device PrimFunc. --- src/tir/transforms/split_host_device.cc | 2 +- .../test_tir_transform_split_host_device.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 3696ff84e5b8..4f47b8ce2bf9 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -282,7 +282,7 @@ Pass SplitHostDevice() { } } mod->Update(device_mod); - return mod; + return ConvertSSA()(mod); }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index f4adac9cf742..680f23e07a17 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -17,6 +17,7 @@ import tvm from tvm import te import tvm.testing +from tvm.script import tir as T, ir as I @tvm.testing.requires_cuda @@ -48,5 +49,29 @@ def test_split_host_device_func_attr(): assert fdevice.attrs["tir.is_global_func"].value +def test_ssa_across_entire_module(): + """The host and device functions should not share TIR vars + + Any arguments that are passed from the host to the device should + be in terms of independent TIR variables. + """ + + @I.ir_module + class before: + @T.prim_func + def main(): + T.func_attr({"global_symbol": "main", "target": T.target("cuda")}) + for i in range(16): + T.attr(0, "device_scope", 0) + for j in range(16): + T.evaluate(i) + + after = tvm.tir.transform.SplitHostDevice()(before) + loop_var = after["main"].body.loop_var + param_var = after["main_kernel0"].params[0] + + assert not loop_var.same_as(param_var) + + if __name__ == "__main__": test_split_host_device_func_attr()