diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 5bd76404a998..cff1e775072f 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -195,20 +195,40 @@ class SEqualReducer { * \param rhs The right operand. * \return the immediate check result. */ - bool operator()(const double& lhs, const double& rhs) const; - bool operator()(const int64_t& lhs, const int64_t& rhs) const; - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; - bool operator()(const int& lhs, const int& rhs) const; - bool operator()(const bool& lhs, const bool& rhs) const; - bool operator()(const std::string& lhs, const std::string& rhs) const; - bool operator()(const DataType& lhs, const DataType& rhs) const; + bool operator()(const double& lhs, const double& rhs, + Optional paths = NullOpt) const; + bool operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const int& lhs, const int& rhs, Optional paths = NullOpt) const; + bool operator()(const bool& lhs, const bool& rhs, Optional paths = NullOpt) const; + bool operator()(const std::string& lhs, const std::string& rhs, + Optional paths = NullOpt) const; + bool operator()(const DataType& lhs, const DataType& rhs, + Optional paths = NullOpt) const; template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs) const { + bool operator()(const ENum& lhs, const ENum& rhs, + Optional paths = NullOpt) const { using Underlying = typename std::underlying_type::type; static_assert(std::is_same::value, "Enum must have `int` as the underlying type"); - return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs); + return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs, paths); + } + + template , ObjectPath>>> + bool operator()(const T& lhs, const T& rhs, const Callable& callable) { + if (IsPathTracingEnabled()) { + ObjectPathPair current_paths = GetCurrentObjectPaths(); + ObjectPathPair new_paths = {callable(current_paths->lhs_path), + callable(current_paths->rhs_path)}; + return (*this)(lhs, rhs, new_paths); + } else { + return (*this)(lhs, rhs); + } } /*! @@ -310,7 +330,8 @@ class SEqualReducer { void RecordMismatchPaths(const ObjectPathPair& paths) const; private: - bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; + bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, + Optional paths = NullOpt) const; bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const ObjectPathPair* paths) const; @@ -321,7 +342,8 @@ class SEqualReducer { template static bool CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data); + const PathTracingData* tracing_data, + Optional paths = NullOpt); /*! \brief Internal class pointer. */ Handler* handler_ = nullptr; diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..ed425cf61441 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,13 +38,21 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index a0343b03955b..73219f0302d9 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -435,9 +435,9 @@ void Evaluate(PrimExpr value); */ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), // String storage_scope = "global", // - bool is_size_var = false) { + bool is_size_var = false, bool is_unknown_type = false) { Type type_annotation{nullptr}; - if (dtype.is_void() && storage_scope == "global") { + if (is_unknown_type && storage_scope == "global") { type_annotation = PrimType(runtime::DataType::Handle()); } else { type_annotation = PointerType(PrimType(dtype), storage_scope); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 48328263fb55..2cb1269d010f 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -272,10 +272,11 @@ PrimFunc Specialize(PrimFunc func, const Map& param_map); * \sa tvm::attr */ namespace attr { + /*! * \brief List of thread IterVar that a DeviceLaunch function corresponds to. * - * Type: Array + * Type: Array * * We call a device kernel launch function f using the following convention: * @@ -283,23 +284,42 @@ namespace attr { * [arg1, arg2, ..., arg_n, * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) * - * Here n = len(arg), m = len(work_size) = len(device_thread_axis). + * Here n = len(arg), m = len(work_size) = len(launch_params)-1. * - * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * The list of kernel launch params indicates which additional + * parameters will be provided to the PackedFunc by the calling + * scope. * - * The list of device_thread_axis indicates how can be bind the - * work_size arguments to the corresponding threads. + * - "threadIdx.x", "threadIdx.y", "threadIdx.z" * - * \sa tvm::CallingConv::kDeviceKernelLaunch - */ -constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; - -/*! - * \brief Whether or not use dynamic shared memory. + * The extent of the thread count in x/y/z, to be used when + * launching the compute kernel on the device. For example, the + * gridDimX/Y/Z parameters passed to cuLaunchKernel when launching a + * CUDA kernel, or the groupCountX/Y/Z parameters passed to + * vkCmdDispatch when dispatching a compute pipeline to Vulkan. * - * Type: Integer + * - "blockIdx.x", "blockIdx.y", "blockIdx.z" + * + * The extent of the block iterators, to be used when launching the + * compute kernel on the device. For example, the blockDimX/Y/Z + * parameters passed to cuLaunchKernel when launching a CUDA kernel. + * For runtimes that do not require the block to be provided + * externally, this parameter is ignored. For example, the + * spv::ExecutionModeLocalSize for SPIR-V shaders on Vulkan, where + * this parameter is defined in the shader. + * + * - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag + * + * The size of the shared memory that may be allocated internally by + * the kernel. For example, exposed as the + * CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in + * cuda. + * + * Defined as "tir.use_dyn_shared_memory". + * + * \sa tvm::CallingConv::kDeviceKernelLaunch */ -constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; +constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params"; /*! * \brief Whether to set noalias rule on the function arguments. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9ed9973871d9..51c99fb375f3 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -734,19 +734,33 @@ class SeqStmt : public Stmt { public: explicit Flattener(Array* seq) : seq_(seq) {} - void operator()(size_t i, const Stmt& stmt) const { - if (!stmt.defined()) return; - if (auto* op = stmt.as()) { - operator()(0, op->seq); - } else { - seq_->push_back(stmt); + template + void operator()(size_t i, const T& stmt_or_seq) const { + if constexpr (std::is_base_of_v) { + // Early bail-out, applicable to any ObjectRef + if (!stmt_or_seq.defined()) return; } - } - template - void operator()(size_t i, const T& seq) const { - for (auto v : seq) { - this->operator()(0, v); + if constexpr (std::is_same_v) { + // No need for dynamic type-checking if the static type is a + // SeqStmt. + (*this)(0, stmt_or_seq->seq); + } else if constexpr (std::is_base_of_v) { + // Dynamic type-checking for a SeqStmt that could be + // flattened. + if (auto* op = stmt_or_seq.template as()) { + operator()(0, op->seq); + } else { + seq_->push_back(stmt_or_seq); + } + } else if constexpr (std::is_base_of_v) { + // Any other Stmt type just gets appended. + seq_->push_back(stmt_or_seq); + } else { + // Anything else is treated as an iterable of Stmt. + for (auto v : stmt_or_seq) { + this->operator()(0, v); + } } } diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d4f537ff3169..35aa392db2de 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect(); */ TVM_DLL Pass Simplify(); +/*! + * \brief Convert an IRModule to be SSA form. + * + * This pass handles cases where the same tir::Var appears in + * multiple functions within the same module. For example, after + * extracting a fragment from one function into another, where the + * same `tir::Var` may be defined both as within the body of the + * original function, and as a parameter within the hoisted function. + * + * \return The pass. + */ +TVM_DLL Pass ConvertSSA(); + /*! * \brief Instruments bound checkers. * diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..232c70aa93d8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. @@ -136,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..b796de8113f3 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,9 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..eabbd188d063 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,68 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict, List + +from tvm.runtime import Object as tvm_Object + +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c3ced1e0338b..a10f4bc6746a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var: +def handle( + dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False +) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1460,7 +1462,10 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b res : PrimExpr The new tir.Var with type handle or casted expression with type handle. """ - return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + is_unknown_type = dtype is None + if dtype is None: + dtype = "void" + return _ffi_api.Handle(dtype, storage_scope, is_size_var, is_unknown_type) # type: ignore[attr-defined] # pylint: disable=no-member def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr: diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..9a2bfdfe486d 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -51,6 +51,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "tvm": tvm, } source = Source(program) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..105164ed5ffc 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -388,6 +403,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..f8c9d4f0afc9 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..201c99074f20 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,10 +32,20 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) + with self.with_dispatch_token("ir"): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -53,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -64,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..63171f672289 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0fe460c085d7..419ab2275858 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -527,7 +527,7 @@ def tvm_struct_set(arr, index, field, value): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_struct_set", arr, index, field, value) + return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value) def address_of(buffer_load, span=None): diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..ba66a6689422 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (!equal(this->attrs, other->attrs)) return false; + if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) { + return false; + } + + if (equal.IsPathTracingEnabled()) { + if ((functions.size() != other->functions.size()) || + (type_definitions.size() != other->type_definitions.size())) { + return false; + } + } - if (functions.size() != other->functions.size()) return false; - // Update GlobalVar remap + // Define remaps for GlobalVar and GlobalTypeVar based on their + // string name. Early bail-out is only performed when path-tracing + // is disabled, as the later equality checks on the member variables + // will provide better error messages. for (const auto& gv : this->GetGlobalVars()) { - if (!other->ContainGlobalVar(gv->name_hint)) return false; - if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + if (other->ContainGlobalVar(gv->name_hint)) { + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; + } } - // Checking functions - for (const auto& kv : this->functions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("functions") - ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; - if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; - } else { - if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (other->ContainGlobalTypeVar(gtv->name_hint)) { + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; } } - if (type_definitions.size() != other->type_definitions.size()) return false; - // Update GlobalTypeVar remap - for (const auto& gtv : this->GetGlobalTypeVars()) { - if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; - if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + // Checking functions and type definitions + if (!equal(this->functions, other->functions, + [](const auto& path) { return path->Attr("functions"); })) { + return false; } - // Checking type_definitions - for (const auto& kv : this->type_definitions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair type_paths = { - obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("type_definitions") - ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false; - } else { - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; - } + if (!equal(this->type_definitions, other->type_definitions, + [](const auto& path) { return path->Attr("type_definitions"); })) { + return false; } + return true; } @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 42726af9859a..66a347f6b8ba 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { template /* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data) { + const PathTracingData* tracing_data, + Optional paths) { if (BaseValueEqual()(lhs, rhs)) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); - return false; } + + if (tracing_data && !tracing_data->first_mismatch->defined()) { + if (paths) { + *tracing_data->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); + } + } + return false; } -bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { +bool SEqualReducer::operator()(const double& lhs, const double& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { +bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { +bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { +bool SEqualReducer::operator()(const int& lhs, const int& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { +bool SEqualReducer::operator()(const bool& lhs, const bool& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { +bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { +bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, - const void* rhs_address) const { + const void* rhs_address, Optional paths) const { if (lhs == rhs) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); - return false; } + + if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { + if (paths) { + *tracing_data_->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_); + } + } + + return false; } const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..92470ec65342 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,15 +26,20 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..0c34f85246c9 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,52 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1fe..f3b547532cfd 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168ad..1c751d40f2e7 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; diff --git a/src/target/build_common.h b/src/target/build_common.h index 35b3d92eb814..7c9ad8cb3c68 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -50,15 +50,9 @@ inline std::unordered_map ExtractFuncInfo(co for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { - auto thread_axis = opt.value(); - for (size_t i = 0; i < thread_axis.size(); ++i) { - info.launch_param_tags.push_back(thread_axis[i]->thread_tag); - } - } - if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { - if (opt.value().IntValue() != 0) { - info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + for (const auto& tag : opt.value()) { + info.launch_param_tags.push_back(tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 534e2c3654c4..36ef44bc4814 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -130,12 +130,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); - - for (IterVar iv : thread_axis) { - runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); - work_dim = std::max(work_dim, scope.dim_index + 1); + auto launch_params = f->GetAttr>(tir::attr::kKernelLaunchParams).value(); + for (const auto& tag : launch_params) { + if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { + runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); + work_dim = std::max(work_dim, scope.dim_index + 1); + } } + if (work_dim != 0) { // use ushort by default for now stream << " "; @@ -145,16 +147,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } - // bind thread axis - for (IterVar iv : thread_axis) { - ICHECK(!var_idmap_.count(iv->var.get())); - std::string vname = iv->thread_tag; - if (work_dim <= 1) { - vname = vname.substr(0, iv->thread_tag.length() - 2); - } - var_idmap_[iv->var.get()] = - CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); - } // the function scope. stream << ") {\n"; int func_scope = this->BeginScope(); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e4569898f7ed..3b07bd943973 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -387,6 +387,18 @@ TVM_REGISTER_NODE_TYPE(PrefetchNode); // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { + bool requires_flattening = std::any_of( + seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance(); }); + + if (requires_flattening) { + auto flattened = SeqStmt::Flatten(seq); + if (auto* ptr = flattened.as()) { + seq = ptr->seq; + } else { + seq = {flattened}; + } + } + auto node = make_object(); node->seq = std::move(seq); node->span = std::move(span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f5063b222b9b..7c693b7efcf7 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -439,9 +439,7 @@ Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { if (seq.same_as(op->seq)) { return GetRef(op); } else { - auto n = CopyOnWrite(op); - n->seq = std::move(seq); - return Stmt(n); + return SeqStmt(seq); } } diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 4c59a1767372..38ec16ea3755 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -88,7 +88,8 @@ PrimFuncPass::PrimFuncPass( // Perform Module -> Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { ICHECK(mod.defined()); - std::vector deleted_list; + std::vector deleted_list; + IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); // directly loop over the underlying dict @@ -101,14 +102,16 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) kv.second = std::move(func); if (!kv.second.defined()) { - deleted_list.push_back(kv.first); + deleted_list.push_back(Downcast(kv.first)); } } } - // automatic removal of None + // Automatic removal of None. This uses IRModuleNode::Remove + // instead of manipulating func_dict directly, to ensure that both + // the function map and the global_var_map_ are correctly updated. for (const auto& gv : deleted_list) { - func_dict->erase(gv); + mod_ptr->Remove(gv); } return mod; } diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index f6e4ac45c612..3d9f19c153e4 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -90,13 +91,106 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { class IRConvertSSA final : public StmtExprMutator { public: - PrimExpr VisitExpr_(const VarNode* op) final { - if (scope_.count(op) && !scope_[op].empty()) { - return scope_[op].back(); - } else { - return GetRef(op); + PrimFunc VisitPrimFunc(PrimFunc func) { + std::vector redefines; + + // Remap parameters, if they were used in another function + auto params = func->params.Map([&](const tir::Var& var) -> tir::Var { + if (defined_.count(var.get())) { + const ScopedRedefine& redefine = redefines.emplace_back(this, var); + return redefine.new_var; + } else { + defined_.insert(var.get()); + return var; + } + }); + + // Remap implicitly defined buffer parameters + { + std::unordered_set defined_params; + for (const auto& var : func->params) { + defined_params.insert(var.get()); + } + for (const auto& [var, buffer] : func->buffer_map) { + auto check_expr = [&](const PrimExpr& expr) { + auto* var_ptr = expr.as(); + if (!var_ptr) return; + if (defined_params.count(var_ptr)) return; + + if (defined_.count(var_ptr)) { + auto var = GetRef(var_ptr); + redefines.emplace_back(this, var); + } else { + defined_.insert(var_ptr); + } + }; + for (const auto& dim : buffer->shape) { + check_expr(dim); + } + for (const auto& stride : buffer->strides) { + check_expr(stride); + } + check_expr(buffer->elem_offset); + } + } + + // Update the buffer map, based on the redefined parameters + auto buffer_map = [&]() { + Map buffer_map; + bool made_change = false; + for (const auto& [var, buffer] : func->buffer_map) { + auto new_var = GetRemappedVar(var); + auto new_buf = GetRemappedBuffer(buffer); + + made_change = made_change || !var.same_as(new_var) || !buffer.same_as(new_buf); + buffer_map.Set(new_var, new_buf); + } + if (made_change) { + return buffer_map; + } else { + return func->buffer_map; + } + }(); + + auto attrs = [&]() -> DictAttrs { + Map dict; + bool made_change = false; + + for (const auto& [key, old_value] : func->attrs->dict) { + auto value = old_value; + if (auto* expr = value.as()) { + value = VisitExpr(GetRef(expr)); + } else if (auto* stmt = value.as()) { + value = VisitStmt(GetRef(stmt)); + } + + made_change = made_change || !value.same_as(old_value); + dict.Set(key, value); + } + + if (made_change) { + return DictAttrs(dict); + } else { + return func->attrs; + } + }(); + + auto body = VisitStmt(func->body); + + // If anything changed, update the returned function + if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) || + !attrs.same_as(func->attrs) || !body.same_as(func->body)) { + func = PrimFunc(params, body, func->ret_type, buffer_map, attrs); + } + + // Pop the redefines in reverse order of creation + while (redefines.size()) { + redefines.pop_back(); } + return func; } + + PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(GetRef(op)); } PrimExpr VisitExpr_(const LetNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { @@ -142,18 +236,27 @@ class IRConvertSSA final : public StmtExprMutator { return node; } + Var GetRemappedVar(Var var) { + if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) { + return it->second.back(); + } else { + return var; + } + } + Buffer GetRemappedBuffer(Buffer buf) { // Determine the buffer var that should be in the updated buffer, // given the current scope. If no redefines are present, then the // buffer var is unchanged. - Var new_buffer_var = buf->data; - auto var_it = scope_.find(buf->data.get()); - if (var_it != scope_.end() && !var_it->second.empty()) { - new_buffer_var = var_it->second.back(); - } + Var new_buffer_var = GetRemappedVar(buf->data); + PrimExpr elem_offset = VisitExpr(buf->elem_offset); + auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); }; + Array shape = buf->shape.Map(visit_expr); + Array strides = buf->strides.Map(visit_expr); // If no mapping is required, return the original buffer. - if (new_buffer_var.same_as(buf->data)) { + if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) && + shape.same_as(buf->shape) && strides.same_as(buf->strides)) { return buf; } @@ -169,9 +272,9 @@ class IRConvertSSA final : public StmtExprMutator { // new buffer, pushing it onto the scoped stack of existing // buffers. This will be popped when the new_buffer_var // redefinition is popped. - Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset, - buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, - buf->axis_separators, buf->span); + Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset, buf->name, + buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, + buf->span); buffers.push_back(new_buf); return new_buf; } @@ -239,16 +342,33 @@ class IRConvertSSA final : public StmtExprMutator { } ~ScopedRedefine() { - parent->scope_[old_var.get()].pop_back(); - for (auto& kv : parent->buf_remap_) { - std::vector& buffers = kv.second; - if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { - buffers.pop_back(); + if (parent) { + parent->scope_[old_var.get()].pop_back(); + for (auto& kv : parent->buf_remap_) { + std::vector& buffers = kv.second; + if (buffers.size() && (buffers.back()->data.get() == new_var.get())) { + buffers.pop_back(); + } } } } - IRConvertSSA* parent; + ScopedRedefine& operator=(const ScopedRedefine&) = delete; + ScopedRedefine(const ScopedRedefine&) = delete; + + ScopedRedefine& operator=(ScopedRedefine&& other) { + swap(other); + return *this; + } + ScopedRedefine(ScopedRedefine&& other) { swap(other); } + + void swap(ScopedRedefine& other) { + std::swap(parent, other.parent); + std::swap(old_var, other.old_var); + std::swap(new_var, other.new_var); + } + + IRConvertSSA* parent{nullptr}; Var old_var; Var new_var; }; @@ -447,5 +567,30 @@ std::pair GetAsyncWaitAttributes(const AttrStmtNode* op) { return std::make_pair(op->value, inner->value); } +namespace transform { +Pass ConvertSSA() { + auto pass_func = [](IRModule mod, PassContext ctx) { + tir::IRConvertSSA converter; + Map functions; + bool made_change = false; + for (auto [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto updated = converter.VisitPrimFunc(GetRef(ptr)); + if (!updated.same_as(base_func)) { + made_change = true; + base_func = updated; + } + } + functions.Set(gvar, base_func); + } + if (made_change) { + mod.CopyOnWrite()->functions = std::move(functions); + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); +} + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b12d3dd49f05..ed3a2da19613 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -266,7 +266,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - PrimExpr node = StringImm("default"); + ObjectRef node = String("default"); seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop)); seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop)); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index e101e6b904ce..519a3e1f80d8 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -69,26 +69,29 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; } - auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; - auto thread_axis = opt_thread_axis.value(); - auto* n = f.CopyOnWrite(); - - // replace the thread axis - for (size_t i = 0; i < thread_axis.size(); ++i) { - auto it = tmap.find(thread_axis[i]->thread_tag); - if (it != tmap.end()) { - thread_axis.Set(i, it->second); + if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { + ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; + auto launch_params = opt.value(); + // replace the thread axis attribute + for (size_t i = 0; i < launch_params.size(); ++i) { + auto it = tmap.find(launch_params[i]->thread_tag); + if (it != tmap.end()) { + launch_params.Set(i, it->second); + } } + + func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params); } + + auto* n = func.CopyOnWrite(); n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body)); - return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); + return func; } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c43fc403ed94..4f47b8ce2bf9 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -51,6 +51,17 @@ class DeviceInfoCollector : public StmtVisitor { PrimExpr dyn_shmem_size_{0}; bool use_dyn_shmem_{false}; + Array GetLaunchParams() const { + Array output; + for (const auto& axis : thread_axis_) { + output.push_back(axis->thread_tag); + } + if (use_dyn_shmem_) { + output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); + } + return output; + } + private: void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -199,8 +210,9 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, dev_info.thread_axis_); + device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams, + dev_info.GetLaunchParams()); + device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, @@ -208,10 +220,7 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); - if (dev_info.use_dyn_shmem_) { - device_func = - WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); - } + (*device_mod_)->Add(kernel_symbol_global, device_func); // generate calls to the device function @@ -273,7 +282,7 @@ Pass SplitHostDevice() { } } mod->Update(device_mod); - return mod; + return ConvertSSA()(mod); }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py index f8dc0f682d06..657bda591ae2 100644 --- a/tests/python/unittest/test_tir_transform_helpers.py +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -17,7 +17,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tir as T, ir as I import tvm.testing @@ -119,5 +119,33 @@ def checker_filter_out_both(func: tvm.tir.PrimFunc): assert len(after.functions) == 0 +class TestFilterRemovesGlobalVarMap(tvm.testing.CompareBeforeAfter): + """Filtering out a function should be identical to never adding it + + This test is to guard against hidden state in the IRModule that + remains after filtering. Previously, this was observed in the + `IRModuleNode::global_var_map_`, which retained entries of + filtered-out functions. + """ + + transform = tvm.tir.transform.Filter(lambda prim_func: False) + + def before(self): + @I.ir_module + class module: + @T.prim_func + def func(): + T.evaluate(0) + + return module + + def expected(self): + @I.ir_module + class module: + pass + + return module + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index f4adac9cf742..680f23e07a17 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -17,6 +17,7 @@ import tvm from tvm import te import tvm.testing +from tvm.script import tir as T, ir as I @tvm.testing.requires_cuda @@ -48,5 +49,29 @@ def test_split_host_device_func_attr(): assert fdevice.attrs["tir.is_global_func"].value +def test_ssa_across_entire_module(): + """The host and device functions should not share TIR vars + + Any arguments that are passed from the host to the device should + be in terms of independent TIR variables. + """ + + @I.ir_module + class before: + @T.prim_func + def main(): + T.func_attr({"global_symbol": "main", "target": T.target("cuda")}) + for i in range(16): + T.attr(0, "device_scope", 0) + for j in range(16): + T.evaluate(i) + + after = tvm.tir.transform.SplitHostDevice()(before) + loop_var = after["main"].body.loop_var + param_var = after["main_kernel0"].params[0] + + assert not loop_var.same_as(param_var) + + if __name__ == "__main__": test_split_host_device_func_attr() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index cd7f1726c9d9..c02fcb852342 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -21,7 +21,8 @@ import tvm import tvm.testing from tvm import tir -from tvm.script import tir as T +from tvm.script import tir as T, ir as I +from tvm.ir.instrument import pass_instrument import numpy as np @@ -181,6 +182,20 @@ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: return main +def copy_using_env_thread(): + shape = (64, 2, 4) + + @T.prim_func + def func(A: T.Buffer(shape), B: T.Buffer(shape)): + blocks, M, N = T.meta_var(shape) + + bx = T.launch_thread("blockIdx.x", blocks) + for i, j in T.grid(M, N): + B[bx, i, j] = A[bx, i, j] + + return func + + def opt_gemm_mod_host(): @tvm.script.ir_module class Module: @@ -3332,6 +3347,25 @@ def func(): return func +def test_void_ptr_vs_handle(): + """Distinguish between void* and handle + + In the future, perhaps these should be de-duplicated by forbidding + one of the two C++ representations. + """ + # Generates PointerType(PrimType(DataType::Void())) + @T.prim_func + def void_ptr(out_ret_value: T.handle("void")): + T.evaluate(out_ret_value) + + # Generates PrimType(DataType::Handle()) + @T.prim_func + def handle(out_ret_value: T.handle): + T.evaluate(out_ret_value) + + assert not tvm.ir.structural_equal(void_ptr, handle) + + def void_ptr(): @T.prim_func def func(out_ret_value: T.handle("void")): @@ -3692,6 +3726,76 @@ def func( return func +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in range(16): + B[i] = A[i] + + return Module + + +def tvm_struct_set_generated_in_cpp(): + """Ensure same dtype for tvm_struct_set in Python/C++ + + The TVMStructSet method in C++, used internally by + LowerTVMBuiltin, and the Python method `T.tvm_struct_set`, used + when parsing TVMScript should use the same dtype "int32". + """ + + @I.ir_module + class Module: + @T.prim_func + def tir_packed_call(A: T.Buffer(16)): + T.attr(0, "device_id", 0) + T.attr(0, "device_type", 0) + T.evaluate( + T.tvm_call_cpacked( + "tvm_test_cpacked", + T.tvm_stack_make_array( + A.data, + T.tvm_stack_make_shape(16, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.Cast("float32", 0), + 0, + dtype="handle", + ), + dtype="int32", + ) + ) + + return tvm.tir.transform.LowerTVMBuiltin()(Module) + + +def nested_seqstmt(): + """Nested SeqStmt should be normalized to flat SeqStmt + + Nested SeqStmt are representable in the TIR structures, but are + flattened when converted to TVMScript. Previously, this could + cause failures to round-trip through TVMScript, including + erroneous use of TVMScript's concise-scoping rules. This was + resolved by normalizing nested SeqStmt in TIR, such that the use + of `tir.SeqStmt` below results in a single flat `tir.SeqStmt` + containing the three `tir.Evaluate` calls. + """ + func = tvm.tir.PrimFunc( + params=[], + body=tvm.tir.SeqStmt( + [ + tvm.tir.SeqStmt([tvm.tir.Evaluate(0), tvm.tir.Evaluate(1)]), + tvm.tir.Evaluate(2), + ] + ), + ) + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3757,6 +3861,9 @@ def func( merge_shape_var_def, if_then_else_var, tvm_shfl_builtins, + ir_module_with_attrs, + tvm_struct_set_generated_in_cpp, + nested_seqstmt, ) @@ -3772,5 +3879,35 @@ def test_return_none_no_trailing_type(): assert "-> None" not in script +@pass_instrument +class ValidateTVMScriptRoundTrip: + def run_after_pass(self, mod, info): + after_roundtrip = tvm.script.from_source(mod.script(show_meta=True)) + tvm.ir.assert_structural_equal(mod, after_roundtrip, True) + + +@pytest.mark.parametrize( + "generator,target", + [ + (matmul, "llvm"), + pytest.param( + launch_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + pytest.param( + copy_using_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + ], +) +def test_roundtrip_all_lowering_steps(generator, target): + func = generator() + + with tvm.transform.PassContext(instruments=[ValidateTVMScriptRoundTrip()]): + tvm.build(func, target=target) + + if __name__ == "__main__": tvm.testing.main()