From 96e05545326970c52d72603f6e27d3d51025dcda Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Mon, 6 Feb 2023 13:43:41 -0800 Subject: [PATCH 1/2] Relax BlockBuilder and ExprMutator --- CMakeLists.txt | 2 + include/tvm/relax/analysis.h | 13 + include/tvm/relax/block_builder.h | 239 +++ include/tvm/relax/expr.h | 7 +- include/tvm/relax/expr_functor.h | 138 +- include/tvm/relax/op_attr_types.h | 75 + include/tvm/relax/struct_info.h | 5 +- include/tvm/relax/utils.h | 155 ++ include/tvm/te/operation.h | 2 +- python/tvm/ir/function.py | 16 + python/tvm/meta_schedule/utils.py | 50 +- python/tvm/relax/__init__.py | 11 + python/tvm/relax/analysis/analysis.py | 33 +- python/tvm/relax/block_builder.py | 801 ++++++++++ python/tvm/relax/expr_functor.py | 1530 ++++++++++++++++++++ python/tvm/relax/op/__init__.py | 22 + python/tvm/relax/op/_ffi_api.py | 19 + python/tvm/relax/op/base.py | 358 +++++ python/tvm/relax/op/binary.py | 67 + python/tvm/relax/utils.py | 278 ++++ python/tvm/te/__init__.py | 1 + python/tvm/te/operation.py | 56 +- src/ir/function.cc | 14 + src/relax/analysis/struct_info_analysis.cc | 149 ++ src/relax/ir/block_builder.cc | 969 +++++++++++++ src/relax/ir/emit_te.cc | 78 + src/relax/ir/emit_te.h | 68 + src/relax/ir/expr_functor.cc | 244 ++++ src/relax/ir/py_expr_functor.cc | 649 +++++++++ src/relax/op/op.cc | 77 + src/relax/op/op_common.cc | 122 ++ src/relax/op/op_common.h | 285 ++++ src/relax/op/tensor/binary.cc | 87 ++ src/relax/op/tensor/binary.h | 71 + src/relax/utils.cc | 41 + src/te/operation/create_primfunc.cc | 80 + src/te/operation/create_primfunc.h | 17 + tests/python/relax/test_blockbuilder.py | 542 +++++++ tests/python/relax/test_expr.py | 4 +- tests/python/relax/test_expr_functor.py | 746 ++++++++++ 40 files changed, 8092 insertions(+), 29 deletions(-) create mode 100644 include/tvm/relax/block_builder.h create mode 100644 include/tvm/relax/op_attr_types.h create mode 100644 include/tvm/relax/utils.h create mode 100644 python/tvm/relax/block_builder.py create mode 100644 python/tvm/relax/expr_functor.py create mode 100644 python/tvm/relax/op/__init__.py create mode 100644 python/tvm/relax/op/_ffi_api.py create mode 100644 python/tvm/relax/op/base.py create mode 100644 python/tvm/relax/op/binary.py create mode 100644 python/tvm/relax/utils.py create mode 100644 src/relax/ir/block_builder.cc create mode 100644 src/relax/ir/emit_te.cc create mode 100644 src/relax/ir/emit_te.h create mode 100644 src/relax/ir/py_expr_functor.cc create mode 100644 src/relax/op/op.cc create mode 100644 src/relax/op/op_common.cc create mode 100644 src/relax/op/op_common.h create mode 100644 src/relax/op/tensor/binary.cc create mode 100644 src/relax/op/tensor/binary.h create mode 100644 src/relax/utils.cc create mode 100644 tests/python/relax/test_blockbuilder.py create mode 100644 tests/python/relax/test_expr_functor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 78f03045e7c4..828d5ed706f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -290,8 +290,10 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/support/*.cc src/script/*.cc src/relax/ir/*.cc + src/relax/op/*.cc src/relax/analysis/*.cc src/relax/backend/vm/*.cc + src/relax/utils.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 82145032f458..6125171598b3 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -85,6 +85,19 @@ TVM_DLL Type GetStaticType(const StructInfo& info); */ TVM_DLL StructInfo StructInfoFromType(const Type& type); +/*! 
+ * \return Derive the call's ret value struct info from inputs. + * \param func_info The function struct info. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The derived struct info of the call. + * \note call->op field is ignored during derivation and we only rely on information + * presented by func_sinfo. + */ +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); + /*! * \brief Erase the info to a corresponding more coarse grained * struct info that is still well-defined(with all the vars in scope). diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h new file mode 100644 index 000000000000..d92e5faf279b --- /dev/null +++ b/include/tvm/relax/block_builder.h @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/block_builder.h + * \brief The utility for constructing Relax binding blocks. + */ +#ifndef TVM_RELAX_BLOCK_BUILDER_H_ +#define TVM_RELAX_BLOCK_BUILDER_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A builder to build Relax binding blocks. + * + * BlockBuilder provides the following three categories + * of main functionalities for IR building and transformations: + * + * - Global context management: manages the IRModule, + * allowing query, update the surrounding global context. + * Provide context tools for analysis. + * - Scope management: + * - Manages block scopes for bulding nested blocks. + * - Emit bindings to the current scope. + * - Construct blocks by calling EndScope. + * - Normalization: Take an Expr, normalize it + * to deduce shape/type, turn things into normal forms. + * + * Importantly, these three categories of features can be dependent + * on each other. For example, when we emit into scope we will call + * normalize to ensure the code is in normal form. Similarly, when we + * normalize we could choose to emit into the current context. + * + * We would encourage the developers to keep these three category + * in mind when using and developing BlockBuilder, we can group + * the code in a logically clean way. + * + * BlockBuilderNode is implemented as a virtual interface to + * allow logically grouped implementation and internal data + * structures that are hidden from the users. + */ +class BlockBuilderNode : public Object { + public: + //------------------------------- + // Global Context management + //------------------------------- + /*! + * \brief Get the name table for generating unique names. + * + * \return The name table. 
+ */ + virtual NameTable* name_table() = 0; + + /*! + * \brief Get the context IRModule in this builder. + * + * \note The context + * \return The IRModule in this BlockBuilder. + */ + virtual IRModule GetContextIRModule() const = 0; + + /*! + * \brief Add a Relax function or a TIR PrimFunc to internal context module. + * \param func The function to be added. + * \param func_name_hint The name hint of the function to be added. + * \note If the function to be added already exists, return its + * GlobalVar directly. + * \return The global var bound to the added function. + */ + virtual GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) = 0; + + /*! + * \brief Update a Relax function or a TIR PrimFunc in the internal context module. + * \param gv The global var referring the function to be updated. + * \param function The updated function. + */ + virtual void UpdateFunction(const GlobalVar& gv, BaseFunc function) = 0; + + /*! + * \brief Report an error during transformation construction. + * \param diagnostic The diagnostic information. + */ + virtual void ReportFatal(const Diagnostic& diagnostic) = 0; + + //------------------------------- + // Scope management + //------------------------------- + /*! + * \brief Lookup the binding value that var binds to in the current emitted sequences. + * \param var The input var. + * \return The Expr bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + virtual Optional LookupBinding(const Var& var) = 0; + + /*! + * \brief Begin a new scope, with optional parameters that + * are visible within the scope. + * + * \param params Parameters that are visible within the scope. + * + * \note This function should be called when new scope is introduced + * (function, seq) to properly track the variable availability + * and help the best effort deduction. + * + * \sa EndScope + */ + virtual void BeginScope(Optional> params) = 0; + + /*! \brief End the previously defined scope. */ + virtual void EndScope() = 0; + + /*! \brief Begin to build a DataflowBlock. */ + virtual void BeginDataflowBlock() = 0; + + /*! \brief Begin to build a BindingBlock. */ + virtual void BeginBindingBlock() = 0; + /*! + * \brief End building a BindingBlock. + * \return The BindingBlock being built. + */ + virtual BindingBlock EndBlock() = 0; + + /*! + * \brief Check if the block being built is DataflowBlock or not. + * \return A boolean that indicates if the block being built is DataflowBlock or not. + */ + virtual bool CurrentBlockIsDataFlow() = 0; + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The new variable that \p expr is bound to. + * + * \note This Emit function normalizes the \p expr, and + * performs shape and type deductions by calling Normalize. + */ + virtual Var Emit(Expr expr, String name_hint = "") = 0; + + /*! + * \brief Emit a MatchCast. + * \param value The input value. + * \param struct_info The struct info to be matched. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to the MatchCast. + */ + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; + + /*! + * \brief Generate an output for the current dataflow block. + * \param output The output variable of the block. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to \p output. 
+ */ + virtual Var EmitOutput(Expr output, String name_hint = "") = 0; + + /*! + * \brief Emit a binding that is already normalized. + * + * \param binding A binding whose value is already normalized. + * + * \note This function requires binding to be pre-normalized. + */ + virtual void EmitNormalized(Binding normalized_binding) = 0; + + /*! + * \brief Convert an expression to normal form, and try to eagerly infer types and shapes. + * \param expr The input expression. + * \return The normalized expression. + * + * \note Invariant: If any of the sub expr have struct_info field. + * they must have already been normalized. + */ + virtual Expr Normalize(const Expr& expr) = 0; + + /*! + * \brief Normalize argument to a call or another IRNode. + * \param expr The input expression. + * \return The normalized expression. + * + * \note This function will create a binding var for non-leaf expressions such as Call. + */ + virtual Expr NormalizeArgument(const Expr& expr) = 0; + + /*! + * \brief Get the analyzer of the BlockBuilder. + * \return The BlockBuilder's arithmetic analyzer. + */ + virtual arith::Analyzer* GetAnalyzer() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.BlockBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); +}; + +class BlockBuilder : public ObjectRef { + public: + /*! + * \brief Create a BlockBuilder. + * + * \param ctx_mod Optional before-transformation context module for rewriting. + * \return The created BlockBuilder. + * + * \note When rewriting an existing IRModule, it is important to pass it in as + * ctx_mod so you can lookup the context functions for cross function + * call analysis. + */ + TVM_DLL static BlockBuilder Create(Optional ctx_mod); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BLOCK_BUILDER_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 9e563c7061dc..0788193ee7c4 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -171,8 +171,7 @@ class CallNode : public ExprNode { // skip sinfo_args check for primitive ops. equal->MarkGraphNode(); return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(sinfo_args, other->sinfo_args)) && - equal(struct_info_, other->struct_info_); + equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -180,9 +179,7 @@ class CallNode : public ExprNode { hash_reduce(op); hash_reduce(args); hash_reduce(attrs); - if (!IsPrimitiveOp(op)) { - hash_reduce(sinfo_args); - } + hash_reduce(sinfo_args); hash_reduce(struct_info_); } diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 5735e8661f6f..ac3ff8d79376 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -26,15 +26,18 @@ #define TVM_RELAX_EXPR_FUNCTOR_H_ #include +#include #include #include #include #include #include +#include #include +#include #include - +#include namespace tvm { namespace relax { @@ -410,6 +413,139 @@ class ExprMutatorBase : public ExprFunctor { DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; }; +/*! + * \brief A mutator works in normal form. 
+ * + * ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ +class ExprMutator : public ExprMutatorBase { + public: + using ExprMutatorBase::VisitExpr_; + + ExprMutator(Optional mod = NullOpt) { builder_ = BlockBuilder::Create(mod); } + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ + virtual void VisitBinding(const Binding& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); + // second level dispatching based on binding value type. + // these dispatching functions get called from first-level dispatch on VarBinding + virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*) + // specific leaf level visitor functions + virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block); + virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for rewriting the var definition site. + * \param var The var to be visited. + * \return The var after post-order rewritten. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual Var VisitVarDef(const Var& var); + // specific leaf level visitor functions + virtual Var VisitVarDef_(const VarNode* var); + virtual Var VisitVarDef_(const DataflowVarNode* var); + + protected: + /*! 
+ * \brief Try to remit binding and bind it to a new_value + * + * This function is called after VisitExpr(binding->value) in + * VisitBinding_(const VarBinding*). + * It will try to reuse the current binding when the new value's shape/type + * matches the original binding and no changes in var is needed. + * + * Otherwise, a new binding will be emitted to replace the var specified in + * the current binding. + */ + void ReEmitBinding(const VarBindingNode* binding, Expr new_value); + + /*! + * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * + * \param body_expr The body to be visited. + * \param params Optional parameters that are visible within the scope. + * \return The expr after visiting. + * + * \note The body_expr must be an SeqExpr in the normal form. + */ + Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + Optional LookupBinding(const Var& var); + + /*! + * \brief Post-order rewrite a node and normalize. + * \param T The node type to be rewritten. + * \param op The node to be rewritten. + * \return The node after post rewritten. + */ + template + Expr VisitExprPostOrder_(const T* op) { + return builder_->Normalize(ExprMutator::VisitExpr_(op)); + } + + /*! + * \brief Create a new var with specified struct_info if the original var's shape or type does + * not match with the specified ones. + * \param var The var to be updated. + * \param struct_info The struct info to be updated. + * \return The var filled with struct_info + */ + Var WithStructInfo(Var var, StructInfo struct_info); + + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; + + /*! \brief Remap a var to a new var in use-site. */ + std::unordered_map var_remap_; + + private: + using TSelf = ExprMutator; + using VisitBindingVTable = + tvm::NodeFunctor; + // initialize the vtable. + static VisitBindingVTable InitVisitBindingVTable(); +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h new file mode 100644 index 000000000000..e171a8d47b0d --- /dev/null +++ b/include/tvm/relax/op_attr_types.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/op_attr_types.h + * \brief Data structures that can appear in operator attributes. 
+ */ +#ifndef TVM_RELAX_OP_ATTR_TYPES_H_ +#define TVM_RELAX_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Infer output struct info given the call + * + * \param call The call expression to be derived. + * \param ctx The builder context. + */ +using FInferStructInfo = + runtime::TypedPackedFunc; + +/*! + * \brief Packed function implementation for operators. The relax operator will be lowered to + * this packed function call during codegen. + */ +using FCallPacked = String; + +struct PrintAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") { + TVM_ATTR_FIELD(format) + .describe("Python-style format string to use for displaying the input. Ignored if empty.") + .set_default(""); + } +}; + +struct AssertOpAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") { + TVM_ATTR_FIELD(format) + .describe( + "Python-style format string to use for displaying " + "an error message if the assert fails. " + "Ignored if empty.") + .set_default(""); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index f38a32f6bb83..b9aebc549474 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -22,16 +22,13 @@ #include #include #include +#include #include #include namespace tvm { namespace relax { -// TODO(relax-team) replace with real BlockBuilder -// once it is ready. -using BlockBuilder = ObjectRef; - /*! * \brief Opaque object. */ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h new file mode 100644 index 000000000000..1457a16427cc --- /dev/null +++ b/include/tvm/relax/utils.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/utils.h + * \brief Utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_UTILS_H_ +#define TVM_RELAX_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Utility data structure for generating unique names for IR construction. + */ +class NameTable { + public: + /*! + * \brief Generate a unique name with a specified prefix. + * \param prefix The name prefix. + * \return The generated name. 
+ */ + inline std::string GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = alloc_map_.find(prefix); + if (it != alloc_map_.end()) { + while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { + } + } + alloc_map_[unique_prefix] = 0; + return unique_prefix; + } + + NameTable() = default; + + template + explicit NameTable(Iter begin, Iter end, Lambda f) { + // static_assert is more reader-friendly than SFINAE when template specialization is not needed. + static_assert(std::is_convertible::value, + "Lambda f must has a signature of [?](*it) -> string {}"); + for (auto it = begin; it != end; ++it) { + const std::string& name = f(*it); + const size_t idx_last_first_num = std::distance( + std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), + name.rend()); + // name = {O = others}{D = consecutive digits} + // let O -> prefix; + std::string prefix = name.substr(0, idx_last_first_num); + ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name; + if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0; + if (idx_last_first_num < name.size()) { // has some digits. + // let D's nearest natural number -> idx; + // note: stoul("000123") = 123; + alloc_map_[prefix] = + std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num))); + } + } + } + + template + explicit NameTable(Iter begin, Iter end) + : NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {} + + private: + std::unordered_map alloc_map_; +}; + +/*! + * \brief Bind the variables to a Relax expression. This is a helper + * function usually called by other pass functions to help optimizations. + * If any free variables are introduced into a function, those are added + * to the function parameters. + * Additionally this may change the order of parameters if you map a variable + * to a variable. + * + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); + +/*! + * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype). + * + * \param ty The input type. + * \param permit_unknown_rank If true, it will permit the input type to have unknown rank + * (ndim of -1), which will require a dynamic check. + * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype + * (namely, void), which will require a dynamic check. + * + * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown + * rank or dtype) + */ +TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); + +/*! + * \brief Check if the given expression is a "leaf" node or tuple node for normalization purposes. + * + * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, + * GlobalVar, Op, ExternFunc. + * + * Tuples are included in this list mainly for convenience in grouping operator arguments. + * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that + * values nested inside them are also leaves. 
+ * + * \param expr The input expression + * + * \return True iff the input expression is a "leaf" node (a value allowed to appear + * inline without being bound to a var during normalization). + */ +TVM_DLL bool IsLeafOrTuple(const Expr& expr); + +/*! + * \brief Copy the given function. The parameters of the original function would be copied to + * satisfy the restriction in the well-formed check: any two functions cannot share the same + * parameter variable. + * \param func The relax function to copy. + * \return The copied function. + */ +TVM_DLL Function CopyWithNewParams(Function func); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_UTILS_H_ diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 2c50f3c3157b..f5753afa560f 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode { } static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c3f1bf5f562a..d02698edb54d 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -66,3 +66,19 @@ def with_attr(self, attr_key_or_dict, attr_value=None): return _ffi_api.BaseFuncWithAttr( res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + + def without_attr(self, attr_key: str) -> "BaseFunc": + """Create a new copy of the function with an attribute without provided key. + + Parameters + ---------- + attr_key : str + The attribute key to delete from the attrubte pairs. + + + Returns + ------- + func : BaseFunc + A new copy of the function + """ + return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 401fdab08a26..9132402b4c9a 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -75,14 +75,27 @@ def _extract(inst: type, name: str): def method(*args, **kwargs): return getattr(inst, name)(*args, **kwargs) - if getattr(base, name) is getattr(cls, name) and name != "__str__": - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - return method + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # for task scheduler return None means calling default function + # otherwise it will trigger a TVMError of method not implemented + # on the c++ side when you call the method, __str__ not required + return None assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) assert hasattr( cls, "_tvm_metadata" ), "Please use the user-facing method overriding class, i.e., PyRunner." 
@@ -95,6 +108,9 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + _cls = cls + _type = "TVMDerivedObject" + def __init__(self, *args, **kwargs): """Constructor.""" self.handle = None @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs): # using weakref to avoid cyclic dependency self._inst._outer = weakref.ref(self) - def __getattr__(self, name: str): - """Bridge the attribute function.""" - try: - return self._inst.__getattribute__(name) - except AttributeError: - return super(TVMDerivedObject, self).__getattr__(name) + def __getattr__(self, name): + # fall back to instance attribute if there is not any + # return self._inst.__getattribute__(name) + import inspect # pylint: disable=import-outside-toplevel + + result = self._inst.__getattribute__(name) + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to aviod implicit deconstruction + setattr(method, "__own__", self) + return method + + return result def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 01310f6455dd..ce175354d02c 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -21,6 +21,8 @@ from . import ty from . import analysis from . import vm +from . import block_builder +from . import op from . import struct_info # Expr @@ -60,6 +62,15 @@ from .exec_builder import ExecBuilder from .vm import VirtualMachine +# Operator +from .op.base import call_tir + +# BlockBuilder +from .block_builder import BlockBuilder + +# ExprFunctor +from .expr_functor import ExprFunctor, PyExprVisitor, PyExprMutator + # StructInfo from .struct_info import ( StructInfo, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 301f3ecc7265..d81c477145ec 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -26,8 +26,8 @@ from tvm import tir from tvm.relax.ty import Type -from tvm.relax.struct_info import StructInfo -from tvm.relax.expr import Var, Expr +from tvm.relax.struct_info import StructInfo, FuncStructInfo +from tvm.relax.expr import Var, Expr, Call from . import _ffi_api @@ -116,6 +116,35 @@ def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckRe return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore +def derive_call_ret_struct_info( + func_sinfo: FuncStructInfo, call: Call, ctx: "tvm.relax.BlockBuilder" +) -> StructInfo: + """Derive the call's ret value struct info from inputs. + + Parameters + ---------- + func_sinfo: FuncStructInfo + The call's function signature. + + call: Call + The call expression + + ctx: tvm.relax.BlockBuilder + The context block builder. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + + Note + ---- + This is an internal derivation function, call.op field is + ignored in this case and the derivation only depends on func_sinfo. + """ + return _ffi_api.DeriveCallRetStructInfo(func_sinfo, call, ctx) # type: ignore + + def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: """Unify the two struct info to their least common ancestor. 
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py new file mode 100644 index 000000000000..77b45fdf5519 --- /dev/null +++ b/python/tvm/relax/block_builder.py @@ -0,0 +1,801 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of constructing Relax AST.""" +import typing + +from typing import Dict, List, Optional, Union, Any, Callable +from tvm.ir.module import IRModule +from tvm.runtime import Object +from tvm import relax as rx, tir +import tvm +from .expr import ( + Expr, + te_tensor, + Var, + ShapeExpr, + GlobalVar, + BindingBlock, + Tuple, + BaseFunc, + Binding, +) +from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo +from .op.base import call_tir +from . import _ffi_api + + +class FunctionScope(object): + """Auxiliary scope for function""" + + def __init__(self, block_builder, name, params, attrs): + self._bb = block_builder + self._name = name + self._params = params + self._attrs = attrs + + def __enter__(self): + self._bb._enter_function_scope(self._name, self._params, self._attrs) + + def __exit__(self, exc_type, exc_val, exc_tb): + # __exit__ should properly handle the case where the with block exits with an exception + # when handling error case in exit, always check if there is already an exception + # been thrown in the with block + self._bb._exit_function_scope(exc_type, exc_val, exc_tb) + + +class DataflowScope(object): + """Auxiliary scope for Dataflow block""" + + def __init__(self, block_builder): + self._bb = block_builder + + def __enter__(self): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_binding_block() + + +class TestingScope(object): + """Auxiliary scope for testing purposes""" + + def __init__(self, block_builder, def_vars): + self._bb = block_builder + shape_vars = [] + for var in def_vars: + if isinstance(var, tvm.tir.Var): + shape_vars.append(var) + else: + raise ValueError("def_vars only can take tir.Var") + # setup a dummy var so shape is in scope. + sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars)) + self._scope_params = [sparam] + + def __enter__(self): + self._bb.begin_scope(self._scope_params) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + self._bb._end_block() + self._bb.end_scope() + + +@tvm._ffi.register_object("relax.BlockBuilder") +class BlockBuilder(Object): + """A builder to build Relax IR for testing and dev. + + Examples + -------- + .. 
code-block:: python + + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16") + bb = rx.BlockBuilder() + with bb.function([x, y], "func"): + with bb.dataflow() as df: + lv0 = bb.emit(rx.add(x, y)) + lv1 = bb.emit(rx.multiply(lv0, y)) + gv0 = bb.emit_output(lv1) + bb.emit_func_output(gv0) + mod = bb.get() + + BlockBuilder can also be used to construct neural networks with nn.Module API + + .. code-block:: python + + from tvm.relax.testing import nn + + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + bb = rx.BlockBuilder() + + with bb.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + mod = bb.get() + """ + + _current = None + + @staticmethod + def current(): + """Returns the current BlockBuilder.""" + return BlockBuilder._current + + def __init__(self, mod: IRModule = None): + self._blocks: List[BindingBlock] = [] + # a boolean flag that tracks if emit_func_output has been called + self._is_emit_func_output_called = False + self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore + + def _begin_dataflow_block(self) -> None: + _ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore + + def _begin_binding_block(self) -> None: + _ffi_api.BlockBuilderBeginBindingBlock(self) # type: ignore + + def _end_block(self) -> BindingBlock: + return _ffi_api.BlockBuilderEndBlock(self) # type: ignore + + def _enter_function_scope(self, name, params, attrs): + if BlockBuilder.current() is not None: + raise RuntimeError("BlockBuilder does not allow nested functions.") + BlockBuilder._current = self + self._func_name = name + self._func_params = params + self._func_attrs = attrs + self.begin_scope(params) + self._begin_binding_block() + + def _exit_function_scope(self, exc_type, exc_val, exc_tb): + # record + is_emit_func_output_called = self._is_emit_func_output_called + # recover to default state + self._blocks = [] + self._is_emit_func_output_called = False + BlockBuilder._current = None + + # NOTE: we must raise after we recover the state so future + # block builder scoping functions correctly + if exc_type is None: + if not is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called in a relax function.") + + def _convert_te_arg( + self, te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr] + ) -> typing.Tuple[Any, List[tvm.te.Tensor]]: + """Helper function used by `call_te` to convert Relax expressions to TE tensor. + + In the common case, the type of te_args is a Relax expression and is converted + into a TE tensor. + If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array), + we recursive and convert any value of type Relax expression into a TE tensor. + Common values of type int, float, and str are preserved. + + In dynamic shape cases, the passed in arguments may contain TIR variable. + For example, the argument can be a Relax Var with TensorStructInfo, which + has symbolic shape, or the argument can be a ShapeExpr with symbolic variables. 
+ To make the PrimFunc generated by `call_te` has independent variables with + the caller Relax function, we will substitute the TIR variables in the input + arguments with fresh ones, which is done by maintaining a TIR variable mapping. + + Parameters + ---------- + te_args : Any + Argument to convert to TE + + tir_var_map : Dict[tir.Var, tir.PrimExpr] + The TIR variable mapping, which maps TIR variables on the Relax function + side to the new set of variables used on the PrimFunc side. + + Returns + ------- + ret : (Any, [tvm.te.Tensor]) + A tuple of the converted te_args, and a list of te tensors for each converted + Relax expression + """ + te_args_list = [] + + def _copy_undefined_var(expr: tir.PrimExpr): + def _visit_expr(e: tir.PrimExpr): + if isinstance(e, tir.Var) and e not in tir_var_map: + new_var = tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + tir.stmt_functor.post_order_visit(expr, _visit_expr) + + def _convert_te_arg_helper(arg): + if isinstance(arg, Expr): # type: ignore + if isinstance(arg.struct_info, TensorStructInfo): + assert isinstance( + arg.struct_info.shape, ShapeExpr + ), "emit_te now only supports Tensor that has ShapeExpr shape" + for shape_value in arg.struct_info.shape.values: + _copy_undefined_var(shape_value) + + arg = te_tensor(arg, tir_var_map) + te_args_list.append(arg) + return arg + elif isinstance(arg.struct_info, ShapeStructInfo): + assert isinstance( + arg, ShapeExpr + ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" + return [_convert_te_arg_helper(val) for val in arg.values] + elif isinstance(arg, (list, tvm.ir.Array)): + return [_convert_te_arg_helper(x) for x in arg] + elif isinstance(arg, tuple): + return tuple([_convert_te_arg_helper(x) for x in arg]) + elif isinstance(arg, (dict, tvm.ir.Map)): + for key in arg: + assert isinstance( + key, str + ), "emit_te only supports dict with string as the key currently" + return {k: _convert_te_arg_helper(arg[k]) for k in arg} + elif isinstance(arg, tir.PrimExpr): + _copy_undefined_var(arg) + return tir.stmt_functor.substitute(arg, tir_var_map) + elif isinstance(arg, (int, float, str, tvm.ir.Type, tvm.ir.Attrs)) or arg is None: + return arg + raise TypeError("not supported type in emit_te: {}".format(type(arg))) + + new_arg = _convert_te_arg_helper(te_args) + return new_arg, te_args_list + + def _get_unbound_tir_vars(self, args: List[tvm.te.Tensor]) -> List[tvm.tir.Var]: + """get unbound TIR vars (i.e TIR vars used in the shape but is not + itself a dimension of a shape)""" + bound_vars = set() + used_vars = set() + + def _populate_used_vars(expr): + if isinstance(expr, tvm.tir.Var): + used_vars.add(expr) + + for x in args: + for s in x.shape: + tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars) + if isinstance(s, tir.Var): + bound_vars.add(s) + + diff = used_vars - bound_vars + return list(diff) + + def function( + self, + name: str, + params: Optional[Union[Var, Tuple, List[Var]]] = None, + attrs: Optional[Dict[str, Object]] = None, + ) -> FunctionScope: + """Annotate a Relax function. + + Parameters + ---------- + name : str, optional + The name of the function + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function. + If params is None, it means deferring initialization of function parameters + until emit_func_output. + + attrs : Dict[str, Object], optional + The function attrs + + Returns + ------- + ret: FunctionScope + A FunctionScope for building a Relax function node. 
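+
+        Example
+        -------
+        A minimal sketch of the deferred-parameter form, where ``params`` is
+        omitted here and supplied to ``emit_func_output`` instead (the variable
+        names below are illustrative only):
+
+        .. code-block:: python
+
+            bb = rx.BlockBuilder()
+            x = rx.Var("x", rx.TensorStructInfo([2, 3], "float32"))
+            with bb.function("func"):
+                lv = bb.emit(rx.Tuple([x]))
+                bb.emit_func_output(lv, params=[x])
+            mod = bb.get()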
+ """ + if not params: + params = None + elif isinstance(params, rx.Var): + params = [params] + elif isinstance(params, (list, tuple)): + for param in params: + if not isinstance(param, rx.Var): + raise TypeError( + "each element of function parameters must be of type tvm.relax.Var,\ + but got: {}".format( + type(param) + ) + ) + if attrs is None: + attrs = {} + return FunctionScope(self, name, params, attrs) + + def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: + """Start a scope for unit-testing purposes. + + Parameters + ---------- + def_vars: List[tir.Var] + List of symbolic variables that are marked as defined in scope. + + Returns + ------- + ret: TestingScope + A TestingScope to setup builder for emit and other purposes. + """ + return TestingScope(self, def_vars) + + def dataflow(self) -> DataflowScope: + """Annotate a Relax dataflow block. + + Returns + ------- + ret: DataflowScope + A DataflowScope for building a Relax dataflow block. + """ + return DataflowScope(self) + + def emit(self, expr: Expr) -> Var: + """Emit an expr. + This infers the shape and type of the expr, create a variable, + and bind the expr to the variable. + + Parameters + ---------- + expr : tvm.relax.Expr + The Expr to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the input expr. + """ + return _ffi_api.BlockBuilderEmit(self, expr) # type: ignore + + def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: + """Generate a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + Please see detailed example in emit_te + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. 
+ + Returns + ------- + ret : tvm.relax.Call + A newly created call node + """ + + primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) + tir_var_map: Dict[tir.Var, tir.PrimExpr] = dict() + new_args, te_arg_list = self._convert_te_arg(args, tir_var_map) + new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs, tir_var_map) + + te_args = te_arg_list + te_kwarg_list + + te_out = func(*new_args, **new_kwargs) + assert isinstance(te_out, tvm.te.tensor.Tensor) or ( + isinstance(te_out, (tuple, list, tvm.ir.Array)) + and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out) + ), "only support te.tensor or tuple/list/Array of te.tensor as function output" + + outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out) + unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs) + + inputs = [*te_args] + outs + tir_func = tvm.te.create_relax_prim_func(inputs, unbound_tir_vars, "int64") + + tir_func = tir_func.without_attr("global_symbol") + + if primfunc_name_hint: + gvar = self.add_func(tir_func, primfunc_name_hint) + else: + gvar = self.add_func(tir_func, func.__name__) + + call_args = [x.op.value for x in te_args] + + def _shape_with_old_tir_var( + shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr] + ): + return ShapeExpr( + [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] + ) + + # Invert the TIR variable mapping, to convert the output shape back + # with old set of variables. + tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} + + output_sinfo = [ + TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) + for out in outs + ] + + # add arguments for extra parameters from unbound var + if len(unbound_tir_vars) > 0: + call = call_tir( + gvar, + call_args, + output_sinfo, + tir_vars=_shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map), + ) + else: + call = call_tir(gvar, call_args, output_sinfo) + return call + + def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the call code. + + Example + ------- + + .. code-block:: python + + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A = args[0] + B = args_dict["B"] + return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + + with bb.function([x, y], "rx_func"): + out = bb.emit_te(te_func, [x], {"B": y}, msg="hello") + bb.emit_func_output(out) + + will result in TVMScript + + .. 
code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, + var_compute: T.handle) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + compute = T.match_buffer(var_compute, [128, 128], dtype="float32") + # body + # with T.block("root") + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]]) + T.writes([compute[i, j]]) + compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j] + + @R.function + def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor: + # block 0 + gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32")) + return gv + + Example + ------- + + .. code-block:: python + + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([n], "float32")) + y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32")) + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1) + + will result in TVMScript + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> None: + rxplaceholder = T.match_buffer(var_rxplaceholder, [n + T.int64(1)], + dtype="float32") + compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32") + # body + # with T.block("root") + for i0 in T.serial(0, n + T.int64(1)): + with T.block("compute"): + i = T.axis.spatial(n + T.int64(1), i0) + T.reads([rxplaceholder[i]]) + T.writes([compute[i]]) + compute[i] = rxplaceholder[i] + + @R.function + def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) + -> Tensor(None, "float32", ndim=-1): + # block 0 + gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,)) + return gv + """ + return self.emit(self.call_te(func, *args, **kwargs)) + + def match_cast(self, value: Expr, struct_info: StructInfo) -> Var: + """Emit a MatchCast. + + Parameters + ---------- + value : tvm.relax.Expr + The value of the MatchCast to be emitted. + + struct_info : StructInfo + The struct info to be matched. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that get bounds to be the casted result. + """ + return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore + + def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: + """Emit output for the current dataflow block or function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + return _ffi_api.BlockBuilderEmitOutput(self, output) # type: ignore + + def emit_func_output( + self, + output: Union[Expr, Tuple, List[Expr]], + params: Optional[Union[Var, Tuple, List[Var]]] = None, + ) -> None: + """Emit output for the function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. 
+ + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function to be built. + If params is None, it means the params have been initialized in the function with scope. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if self._is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called exactly once in a relax function.") + self._is_emit_func_output_called = True + + if self._func_params is not None and params is not None: + raise RuntimeError( + "function parameters have been initialized in the function with scope." + ) + + if self._func_params is None and params is None: + raise RuntimeError("Relax function must have parameter.") + + if self._func_params is None: + self._func_params = params + + if BlockBuilder.current() is not self: + raise RuntimeError("BlockBuilder._current must be self.") + + if isinstance(output, (list, tuple)): + output = Tuple(output) + + block = self._end_block() + if len(block.bindings) > 0: + self._blocks.append(block) + seqe = self.normalize(rx.SeqExpr(self._blocks, output)) + + # do not specify ret_struct_info and let constructor deduce + # from seqe.struct_info + func = rx.Function(self._func_params, seqe) + for key, value in self._func_attrs.items(): + func = func.with_attr(key, value) + self.end_scope() + self.add_func(func, self._func_name) + + def normalize(self, expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.BlockBuilderNormalize(self, expr) # type: ignore + + def get(self) -> tvm.IRModule: + """Return the IRModule being built. + + Returns + ------- + ret : tvm.IRModule + An IRModule with Relax and TIR functions being built. + """ + return _ffi_api.BlockBuilderGetContextIRModule(self) # type: ignore + + def get_unique_name(self, name_prefix: str) -> str: + """Generate a unique name with a specified prefix. + + Parameters + ---------- + name_hint : str + The name prefix. + + Returns + ------- + ret : str + The generated name. + """ + return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix) # type: ignore + + def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + func : BaseFunc + The function to be added. + + func_name : str + The name of the function to be added. + + Returns + ------- + gvar : GlobalVar + The global var bound to the added function. + """ + return _ffi_api.BlockBuilderAddFunction(self, func, func_name) # type: ignore + + def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + gv : GlobalVar + The global var referring the function to be updated. + + updated_func : BaseFunc + The updated function. + """ + return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) # type: ignore + + def current_block_is_dataflow(self) -> bool: + """Check if the block being built is DataflowBlock or not. + + Returns + ------- + ret : bool + A boolean that indicates if the block being built is DataflowBlock or not. + """ + return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) # type: ignore + + def emit_normalized(self, binding: Binding) -> None: + """Emit an already normalized binding. 
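+
+        The binding is emitted as-is and is expected to be already normalized;
+        no additional normalization is applied.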
+ + Parameters + ---------- + binding: Binding + The binding to be emitted. + """ + _ffi_api.BlockBuilderEmitNormalized(self, binding) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Lookup a var in the binding table binding_table_. + + Parameters + ---------- + var: Var + The input var. + + Returns + ------- + expr: Expr + The Expr bound to the input var. + """ + return _ffi_api.BlockBuilderLookupBinding(self, var) # type: ignore + + def begin_scope(self, params: Optional[List[Var]] = None) -> None: + """Begin a new scope, with optional parameters that + are visible within the scope. + + Parameters + ---------- + params: Optional[List[Var]] + Parameters that are visible within the scope. + + Note + ---- + This function should be called when new scope is introduced + (function, seq) to properly track the variable availability + and help the best effort deduction. + """ + + return _ffi_api.BlockBuilderBeginScope(self, params) # type: ignore + + def end_scope(self) -> None: + """End the current scope. Please see `begin_scope` for details""" + + return _ffi_api.BlockBuilderEndScope(self) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py new file mode 100644 index 000000000000..0252720f6ee8 --- /dev/null +++ b/python/tvm/relax/expr_functor.py @@ -0,0 +1,1530 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ +"""The expression functor of Relax.""" +from typing import Callable, Optional + +import tvm +from tvm.ir import Op +from tvm.meta_schedule.utils import derived_object +from tvm.runtime import Object + +from ..ir.module import IRModule +from . import _ffi_api +from .block_builder import BlockBuilder +from .expr import ( + Binding, + BindingBlock, + Call, + Constant, + Id, + DataflowBlock, + DataflowVar, + DataTypeImm, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + PrimValue, + SeqExpr, + ShapeExpr, + Span, + StringImm, + Tuple, + TupleGetItem, + Var, + VarBinding, +) +from .struct_info import StructInfo + +visitor = derived_object +""" +A decorator to wrap user-customized PyExprVisitor as TVM object _PyExprVisitor. + +Parameters +---------- +visitor_cls : PyExprVisitor + The user-customized PyExprVisitor. + +Returns +------- +cls : _PyExprVisitor + The decorated TVM object _PyExprVisitor(ExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.visitor + class MyExprVisitor(PyExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... 
+ # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyExprVisitor() + # apply myvisitor to Expr/Binding/BindingBlock/VarDef + myvisitor.visit_expr(expr) + myvisitor.visit_binding(binding) + myvisitor.visit_binding_block(bindingblock) + myvisitor.visit_var_def(var) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyExprMutator as TVM object _PyExprMutator. +Note: Cannot override visit function and post-order rewrite at the same time. + +Parameters +---------- +mutator_cls : PyExprMutator + The user-customized PyExprMutator. + +Returns +------- +cls : _PyExprMutator + The decorated TVM object _PyExprMutator(ExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.mutator + class MyExprMutator(PyExprMutator): + # customize rewrite function + def visit_tuple_(self, op: Tuple) -> Expr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Tuple with + # user-customized visit_tuple_ + mymutator = MyExprMutator() + # apply mymutator to Expr/Binding/BindingBlock/VarDef + mymutator.visit_expr(expr) + mymutator.visit_binding(binding) + mymutator.visit_binding_block(bindingblock) + mymutator.visit_var_def(var) +""" + + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + Defines the default dispatch over expressions, and + implements memoization. + """ + + def visit_expr(self, expr: Expr) -> Expr: + """Apply the visitor to an expression.""" + if isinstance(expr, Constant): # type: ignore + ret = self.visit_constant_(expr) + elif isinstance(expr, Tuple): + ret = self.visit_tuple_(expr) + elif isinstance(expr, DataflowVar): + ret = self.visit_dataflow_var_(expr) + elif isinstance(expr, Var): + ret = self.visit_var_(expr) + elif isinstance(expr, ShapeExpr): + ret = self.visit_shape_expr_(expr) + elif isinstance(expr, ExternFunc): + ret = self.visit_extern_func_(expr) + elif isinstance(expr, GlobalVar): # type: ignore + ret = self.visit_global_var_(expr) + elif isinstance(expr, Function): + ret = self.visit_function_(expr) + elif isinstance(expr, Call): # type: ignore + ret = self.visit_call_(expr) + elif isinstance(expr, SeqExpr): + ret = self.visit_seq_expr_(expr) + elif isinstance(expr, If): # type: ignore + ret = self.visit_if_(expr) + elif isinstance(expr, Op): + ret = self.visit_op_(expr) + elif isinstance(expr, TupleGetItem): + ret = self.visit_tuple_getitem_(expr) + elif isinstance(expr, PrimValue): + ret = self.visit_prim_value_(expr) + elif isinstance(expr, StringImm): + ret = self.visit_string_imm_(expr) + elif isinstance(expr, DataTypeImm): + ret = self.visit_data_type_imm_(expr) + else: + raise TypeError("Invalid type: {0}".format(type(expr))) + + return ret + + def visit_constant_(self, op: Constant): + raise NotImplementedError() + + def visit_tuple_(self, op: Tuple): + raise NotImplementedError() + + def visit_dataflow_var_(self, op: DataflowVar): + raise NotImplementedError() + + def visit_var_(self, op: Var): + raise NotImplementedError() + + def visit_shape_expr_(self, op: ShapeExpr): + raise NotImplementedError() + + def visit_extern_func_(self, op: ExternFunc): + raise NotImplementedError() + + def visit_global_var_(self, op: GlobalVar): + raise NotImplementedError() + + def visit_function_(self, op: Function): + raise NotImplementedError() + + def visit_call_(self, op: Call): + raise NotImplementedError() + + def visit_seq_expr_(self, op: SeqExpr): + raise NotImplementedError() + + def visit_if_(self, 
op: If): + raise NotImplementedError() + + def visit_op_(self, op: Op): + raise NotImplementedError() + + def visit_tuple_getitem_(self, op: TupleGetItem): + raise NotImplementedError() + + def visit_prim_value_(self, op: PrimValue): + raise NotImplementedError() + + def visit_string_imm_(self, op: StringImm): + raise NotImplementedError() + + def visit_data_type_imm_(self, op: DataTypeImm): + raise NotImplementedError() + + def visit_var_binding_(self, binding: VarBinding): + raise NotImplementedError() + + def visit_match_cast_(self, binding: MatchCast): + raise NotImplementedError() + + def visit_binding_block_(self, block: BindingBlock): + raise NotImplementedError() + + def visit_dataflow_block_(self, block: DataflowBlock): + raise NotImplementedError() + + def visit_var_def_(self, var: Var): + raise NotImplementedError() + + def visit_dataflow_var_def_(self, var: DataflowVar): + raise NotImplementedError() + + def visit_binding(self, binding: Binding): + if isinstance(binding, MatchCast): + self.visit_match_cast_(binding) + elif isinstance(binding, VarBinding): + self.visit_var_binding_(binding) + else: + raise TypeError("Invalid type: {0}".format(type(binding))) + + def visit_binding_block(self, block: BindingBlock): + if isinstance(block, DataflowBlock): + self.visit_dataflow_block_(block) + elif isinstance(block, BindingBlock): + self.visit_binding_block_(block) + else: + raise TypeError("Invalid type: {0}".format(type(block))) + + def visit_var_def(self, var: Var): + if isinstance(var, DataflowVar): + self.visit_dataflow_var_def_(var) + elif isinstance(var, Var): + self.visit_var_def_(var) + else: + raise TypeError("Invalid type: {0}".format(type(var))) + + +@tvm._ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(Object): + """ + A TVM object to support customization of ExprVisitor on the python side. + This is the decorated result returned from visitor decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. 
+ + See also: visitor, PyExprVisitor + """ + + def __init__( + self, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprVisitor, # type: ignore + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + return _ffi_api.PyExprVisitorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprVisitorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + return _ffi_api.PyExprVisitorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + return _ffi_api.PyExprVisitorVisitVarDef(self, var) # type: ignore + + +class PyExprVisitor: + """ + An abstract ExprVisitor with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods + that users can overwrite("methods"). + + Note: @relax.expr_functor.visitor is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.visitor + def MyExprVisitor(PyExprVisitor): + ... 
+ """ + + _tvm_metadata = { + "cls": _PyExprVisitor, + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> None: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> None: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> None: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + """Visit DataflowVar. 
+ Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> None: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> None: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> None: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> None: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> None: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> None: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> None: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. 
+ """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> None: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> None: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> None: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> None: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Visit the DataflowVar definition site. 
+ Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> None: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) # type: ignore + + +@tvm._ffi.register_object("expr_functor.PyExprMutator") +class _PyExprMutator(Object): + """ + A TVM object to support customization of ExprMutator on the python side. + This is the decorated result returned from mutator decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: mutator, PyExprmutator + """ + + def __init__( + self, + builder: BlockBuilder = None, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprMutator, # type: ignore + builder, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + return _ffi_api.PyExprMutatorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprMutatorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. 
+ + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + return _ffi_api.PyExprMutatorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitVarDef(self, var) # type: ignore + + +class PyExprMutator: + """ + An abstract ExprMutator with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods that users can + overwrite("methods"), the constructor's parameters("fields") + + Note: @relax.expr_functor.mutator is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.mutator + def MyExprMutator(PyExprMutator): + ... + """ + + _tvm_metadata = { + "cls": _PyExprMutator, + "fields": ["builder_"], + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def __init__(self, mod: Optional[IRModule] = None) -> None: + """Constructor""" + self.builder_ = BlockBuilder(mod) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. 
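# A typical driver for a PyExprMutator-based pass, sketched under the assumption
# that MyExprMutator is a @relax.expr_functor.mutator-decorated subclass like the
# one in the class docstring above: rewrite every Relax function in a module and
# collect the result from the mutator's internal BlockBuilder (builder_).
import tvm
from tvm import relax

def apply_my_mutator(mod: tvm.IRModule) -> tvm.IRModule:
    m = MyExprMutator(mod)                 # builder_ wraps the input IRModule
    for gv, func in mod.functions.items():
        if isinstance(func, relax.Function):
            new_func = m.visit_expr(func)  # generic dispatcher documented above
            m.builder_.update_func(gv, new_func)
    return m.builder_.get()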
+ Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result: Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> Expr: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> Expr: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> Expr: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> Expr: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. 
+ + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> Expr: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> Expr: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> Expr: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> Expr: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> Expr: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> Expr: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. 
+ + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> Var: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. 
+ """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> Span: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + + Returns + ------- + result : Span + The span after transformation. + """ + raise NotImplementedError + + def visit_expr_post_order(self, expr: Expr) -> Expr: + """Post-order rewrite an Expr and normalize. + + Parameters + ---------- + expr : Expr + The Expr to be rewritten. + + Returns + ------- + result : Expr + The Expr after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitExprPostOrder(self._outer(), expr) # type: ignore + + def set_var_remap(self, vid: Id, var: Var) -> None: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var. + var : Var + The new var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorSetVarRemap(self._outer(), vid, var) # type: ignore + + def get_var_remap(self, vid: Id) -> Var: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var + + Returns + ------- + var : Var + The remapped var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorGetVarRemap(self._outer(), vid) # type: ignore + + def visit_with_new_scope(self, expr: Expr) -> Expr: + """Rewrite the expr with a new scope, used in a Function's body and the branches of If. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + var : Var + The expr after visiting. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitWithNewScope(self._outer(), expr) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Look up the value bound to a variable. + Note: For function parameters, this function returns NullOpt. + + Parameters + ---------- + var : Var + The var to be looked up. + + Returns + ------- + var : Var + The value bound to the input var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore + + def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: + """Create a new var with specified shape and type if the original var's shape or type does + not match with the specified ones. + + Parameters + ---------- + var : Var + The var to be updated. + struct_info : StructInfo + The struct info. + + Returns + ------- + var : Var + The var filled with shape and type. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py new file mode 100644 index 000000000000..101b0827d630 --- /dev/null +++ b/python/tvm/relax/op/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax core operators.""" + +# Operators +from .base import * +from .binary import * diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py new file mode 100644 index 000000000000..8dc6a1b4fbb0 --- /dev/null +++ b/python/tvm/relax/op/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op""" +import tvm._ffi + +tvm._ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py new file mode 100644 index 000000000000..d76b155beb83 --- /dev/null +++ b/python/tvm/relax/op/base.py @@ -0,0 +1,358 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# pylint: disable=redefined-builtin +"""The base Relax operators.""" +from typing import Union, List, Tuple, Optional + + +import tvm +from tvm.runtime.object import Object + +from . import _ffi_api +from ..expr import Expr, ShapeExpr, Call, ExternFunc +from ..expr import Tuple as RxTuple +from ..struct_info import StructInfo, TensorStructInfo +from ...ir import PrimExpr +from ..utils import args_converter + + +py_print = print # pylint: disable=invalid-name + + +def null_value() -> Call: + """Create a call node that represents a null value object. + + Returns + ------- + ret: Call + The created call node. 
+ """ + return _ffi_api.null_value() # type: ignore + + +@args_converter.auto +def call_tir( + func: Union[str, Expr], + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], + tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, +) -> Call: + """ + Call a destination-passing-style function and return the output. + + Parameters + ---------- + func : Union[str, Expr] + The destination-passing-style function, can be ExternFunc or PrimFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_tir output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used + + Returns + ------- + ret: Call + A call node for the call_tir operator. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + if isinstance(tir_vars, (list, tuple)): + tir_vars = ShapeExpr(tir_vars) + + return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore + + +@args_converter.auto +def call_builtin_with_ctx( + func: Union[str, Expr], + args: Expr, + *, + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None, +) -> Call: + """Call a builtin function func. + + Parameters + ---------- + func : Expr + The builtin function to be called. + + args : Expr + The input arguments. + + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] + The struct info arguments to the call node. + + Returns + ------- + ret: Call + The created call node. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if sinfo_args is not None and not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.call_builtin_with_ctx( # type: ignore + func, + args, + sinfo_args, # type: ignore + ) + + +@args_converter.auto +def make_closure( + func: Expr, + args: Expr, +) -> Object: + """ + Create a closure with free variables and return the closure. + + Parameters + ---------- + func : Expr + The closure, can be ExternFunc or PrimFunc. + + args : Expr + The input arguments. + + + Returns + ------- + ret: Object + The VMClosure. + """ + + return _ffi_api.make_closure(func, args) # type: ignore + + +@args_converter.auto +def invoke_closure( + closure: Expr, + args: Expr, + sinfo_args: Union[List[StructInfo], StructInfo], +) -> Object: + """ + Invoke a closure. + + Parameters + ---------- + closure : Expr + The VMClosure object. + + args : Expr + The input arguments. + + type_args: Union[List[StructInfo], StructInfo] + The structure info arguments of the CallNode + + Returns + ------- + ret: Object + The result. + """ + + if not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore + + +def render_object(val: tvm.Object) -> str: + """ + Given a TVM Object, renders it in string form. Used for Relax printing and assertions. 
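# Minimal sketch of building a call_tir node as documented above: create a small
# PrimFunc with te.create_prim_func (also touched by this patch), register it on
# a BlockBuilder, and call it in destination-passing style. Names such as
# "double" are illustrative only.
from tvm import te, relax

A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
prim = te.create_prim_func([A, B])

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32"))
gv = bb.add_func(prim, "double")                   # GlobalVar for the PrimFunc
call = relax.op.call_tir(
    gv,
    (x,),                                          # packed into a relax Tuple
    out_sinfo=relax.TensorStructInfo((128, 128), "float32"),
)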
+ + Parameters + ---------- + val: tvm.Object + An object to render + + Returns + ------- + ret: str + A string representing the value, ideally human-readable + """ + if isinstance(val, tvm.runtime.ndarray.NDArray): + return str(val) + # no pretty-printer by default, so if we don't handle this, + # then we can't look inside tuples + if isinstance(val, tvm.runtime.container.ADT): + # the fields array of an ADT cannot be directly accessed in Python + # so we have to get the length and index into the fields separately + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + # special case: tag = 0 is a tuple + if val.tag == 0: + return f"({fields})" + return f"ADT(tag={val.tag}, fields=[{fields}])" + return str(val) + + +@tvm.register_func("relax.run.print") +def relax_print(format_str: str, *format_args: tvm.Object) -> None: + """ + Takes a list of values to print, formats with the given format string. + If the format string is empty, simply prints. + + Call from TVM script like this: + `relax.print(value1, value2, ..., valueN, format=format_str)` + or + `relax.print(value1, value2, ..., valueN) # format_str defaults to ""` + + Parameters + ---------- + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[Object] + The values to print. + """ + val_strs = map(render_object, format_args) + if format_str == "": + py_print(*val_strs) + else: + py_print(format_str.format(*val_strs)) + + +def print(*values: List[Expr], format: str = "") -> Expr: + """Print op to print the values + + Parameters + ---------- + values : List[Expr] + The values to print. + + format_str: str + The format string. + + Returns + ------- + result : Expr + A relax Call, which will print the value during runtime. + """ + return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member + + +@tvm.register_func("relax.run.assert_op") +def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: + """ + A variadic function. The first value serves as the assertion condition: + If the condition is true, then the operator does nothing. + If the condition is false, then the operator raises an assertion error. + + Arguments after the first value serve as format arguments for the error message; + the last argument must be a format string for the error message (empty by default). + If the format string is the empty string, then the error message will simply include + a comma-separated list of the format arguments. + The condition argument is not included in the format string. + + Parameters + ---------- + condition: tvm.Object + The assertion condition. Must be a boolean scalar. + + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[tvm.Object] + Values used for formatting the string. 
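# Sketch of how the print op above is meant to be used: relax.op.print builds a
# Call; the actual printing happens at runtime through relax.run.print, which
# renders each argument with render_object(). The variable is an example only.
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((2,), "float32"))
p = relax.op.print(x, format="x is: {}")   # a relax Call, printed when executed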
+ """ + if not isinstance(format_str, str): + raise ValueError( + f"The format string argument to assert must be a string, given {type(format_str)})" + ) + + # should be guaranteed by the type system + if not isinstance(condition, tvm.runtime.ndarray.NDArray): + raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") + + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + if not val: + error_message = "Assertion Failed" + if format_args or format_str != "": + rendered = map(render_object, format_args) + if format_str != "": + error_message = format_str.format(*rendered) + else: + error_message = ", ".join(rendered) + raise AssertionError(error_message) + + +def assert_op( + condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None, format: str = "" +) -> Expr: + """ + Create a call to Relax's assert_op operation (`assert` is reserved in Python, + so the name must be distinct). + + Parameters + ---------- + condition: Expr + The assertion condition. + + format_args: Optional[Union[Expr, List[Expr]]] + Format arguments for the error message if the condition fails. + + format_str: str + The format string for the error message. + + Returns + ------- + result : Expr + A Call to the Relax assert operation. + """ + if format_args is None: + format_args = [] + if isinstance(format_args, Expr): # type: ignore + format_args = [format_args] + return _ffi_api.assert_op(condition, format_args, format) # type: ignore + + +def shape_of(expr: Expr) -> Expr: + """Get shape of a tensor. + + Parameters + ---------- + expr : Expr + The input Expr. + + Returns + ------- + result : Expr + A relax Call, which gets the shape of the input + """ + return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py new file mode 100644 index 000000000000..eee0b6f3366a --- /dev/null +++ b/python/tvm/relax/op/binary.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax binary arithmetic and comparison operators.""" +from . import _ffi_api +from ..expr import Expr + +###################### Arithmetic operators ###################### + + +def add(x1: Expr, x2: Expr) -> Expr: + """Addition with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. 
+ + Returns + ------- + result : Expr + The computed result. + + Examples + -------- + .. code:: python + + bb = relax.BlockBuilder() + a = relax.Var("a", relax.TensorStructInfo(shape=(2, 3), dtype="float32")) + b = relax.Var("b", relax.TensorStructInfo(shape=(2, 1), dtype="float32")) + c = bb.normalize(relax.op.add(a, b)) # c has TensorStructInfo(shape=(2, 3), dtype="float32") + """ + return _ffi_api.add(x1, x2) # type: ignore + + +def multiply(x1: Expr, x2: Expr) -> Expr: + """Multiplication with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + result : Expr + The computed result. + """ + return _ffi_api.multiply(x1, x2) # type: ignore diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py new file mode 100644 index 000000000000..5bfb0d87bf00 --- /dev/null +++ b/python/tvm/relax/utils.py @@ -0,0 +1,278 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for Relax""" +import functools +import inspect +from typing import Any, Callable, List, Optional, TypeVar + +from .. import tir +from ..runtime import String, convert_to_object +from ..tir import PrimExpr +from . import _ffi_api +from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm +from .expr import Tuple as rx_Tuple + + +def metadata_partitioner(rx_txt: str) -> List[str]: + """Extract Relax program and metadata section. + + Parameters + ---------- + rx_txt : str + The input relax text. + + Returns + ------- + output : List[str] + The result list of partitioned text, the first element + is the relax program, and the second is metadata section. + """ + partitions = [] + left_curly = 0 + meta_start = 0 + meta_end = 0 + for i, char in enumerate(rx_txt): + if i < 0: + raise ValueError("The program is invalid.") + if char == "{": + if meta_start == 0: + meta_start = i + left_curly += 1 + elif char == "}": + left_curly -= 1 + if left_curly == 0: + meta_end = i + 1 + break + + if meta_end == 0: + raise ValueError("The metadata section was not found.") + metadata = rx_txt[meta_start:meta_end] + rx_program = rx_txt[meta_end:-1] + + partitions.append(rx_program) + partitions.append(metadata) + + return partitions + + +def convert_to_expr(value: Any) -> Expr: + """Helper function to convert the input to Expr, which follows the rules: + 1. Return the input itself if it's already a `relax.Expr`; + 2. Return `relax.PrimValue` if the input is a `PrimExpr`; + 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; + 4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype; + 5. Return `relax.Tuple` if the input is a tuple/list of `Expr`. + + Notes + ----- + 1. 
`tvm.tir.StringImm` is not allowed because of ambiguity, + which can be either `relax.StringImm` or `relax.PrimValue`. + 2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr` + """ + if isinstance(value, int): + return PrimValue(tir.IntImm("int64", value)) + + tvm_value = convert_to_object(value) + # Case 1 + if isinstance(tvm_value, Expr): # type: ignore + return tvm_value + # Note`` 1 + if isinstance(tvm_value, tir.StringImm): + raise TypeError( + "Cannot convert `tir.StringImm` to `relax.Expr` because of ambiguity," + "which can be either `relax.StringImm` or `relax.PrimValue` " + ) + # Case 2 + if isinstance(tvm_value, PrimExpr): + return PrimValue(value) + # Case 3 + if isinstance(tvm_value, String): + return StringImm(value) + # Case 4 & 5 + if isinstance(value, (tuple, list)): + # Note 2 + if len(value) == 0: + return rx_Tuple([]) + # Case 4 + opt_prim_value = [convert_to_object(v) for v in value] + if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]): + return ShapeExpr(value) + # Case 5 + # `convert_to_expr` ensures that all elements are `Expr` if no exception raises + return rx_Tuple([convert_to_expr(v) for v in value]) + raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`") + + +FType = TypeVar("FType", bound=Callable[..., Expr]) + + +class _ArgsConverter: + """A helper class to convert the arguments to Expr.""" + + @staticmethod + def convert(args_to_expr: List[str], args_to_list_expr: List[str]): + """Convert the arguments to Expr. + + Parameters + ---------- + args_to_expr : List[str] + The argument names to be converted to Expr. + + args_to_list_expr : List[str] + The argument names to be converted to List[Expr]. + + Returns + ------- + output : Callable[[FType], FType] + The decorator. 
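# Quick illustration of the conversion rules implemented above, assuming the
# helper is imported from tvm.relax.utils; n is a symbolic int64 tir variable.
from tvm import tir, relax
from tvm.relax.utils import convert_to_expr

n = tir.Var("n", "int64")
convert_to_expr(3)                           # relax.PrimValue(3)
convert_to_expr("name")                      # relax.StringImm("name")
convert_to_expr((2, n))                      # relax.ShapeExpr([2, n]); all int PrimExprs
convert_to_expr([relax.StringImm("a"), 1])   # relax.Tuple([...]); mixed elements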
+ """ + + if any([x in args_to_list_expr for x in args_to_expr]): + raise ValueError(f"`args_to_expr` and `args_to_list_expr` should be disjoint.") + + def _convert(name: str, value: Any) -> Any: + if value is None: + return value + if name in args_to_expr: + try: + return convert_to_expr(value) + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `Expr`, " + f"but failed with input value: {value}" + ) + elif name in args_to_list_expr: + try: + return [convert_to_expr(x) for x in value] + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `List[Expr]`, " + f"but failed with input value: {value}" + ) + else: + return value + + def inner(func: FType) -> FType: + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + for name in args_to_expr + args_to_list_expr: + if name not in param_names: + raise ValueError(f"Argument `{name}` is not found in function signature.") + + @functools.wraps(func) + def wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + for param in sig.parameters.values(): + if param.kind == param.VAR_POSITIONAL: + # *args case + values = [_convert(param.name, x) for x in bound.arguments[param.name]] + bound.arguments[param.name] = tuple(values) + elif param.kind == param.VAR_KEYWORD: + # **kwargs case + key_value = { + key: _convert(param.name, value) + for key, value in bound.arguments[param.name].items() + } + bound.arguments[param.name] = key_value + else: + bound.arguments[param.name] = _convert( + param.name, bound.arguments[param.name] + ) + return func(*bound.args, **bound.kwargs) + + return wrapper # type: ignore + + return inner + + @staticmethod + def to_expr(*arg_names: str) -> Callable: + """Convert the arguments to Expr. + + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[]) + + @staticmethod + def to_list_expr(*arg_names: str) -> Callable: + """Convert the arguments to List of Expr. + + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to List of Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names)) + + @staticmethod + def auto(func: FType) -> FType: + """Decorator for automatically convert the arguments to Expr according to type annotation. + Only two patterns are supported: + + 1. The argument is Expr or Optional[Expr]. + + 2. The argument is List[Expr] or Optional[List[Expr]]. + + """ + sig = inspect.signature(func) + args_to_expr = [] + args_to_list_expr = [] + + for param in sig.parameters.values(): + anno = param.annotation + if anno in (Expr, Optional[Expr]): + args_to_expr.append(param.name) + if anno in (List[Expr], Optional[List[Expr]]): + args_to_list_expr.append(param.name) + + return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func) + + +args_converter = _ArgsConverter() # pylint: disable=invalid-name + + +def copy_with_new_params(func: Function) -> Function: + """Copy the given function. The parameters of the original function would be copied to + satisfy the restriction in the well-formed check: any two functions cannot share the same + parameter variable. + + Parameters + ---------- + func : Function + The relax function to copy. 
+
+    Returns
+    -------
+    ret : Function
+        The copied function.
+    """
+    return _ffi_api.CopyWithNewParams(func)  # type: ignore
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 0907ea2ebf85..40fac0f92f6d 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -41,6 +41,7 @@
 from .operation import placeholder, compute, scan, extern, var, size_var, const
 from .operation import thread_axis, reduce_axis
 from .operation import create_prim_func
+from .operation import create_relax_prim_func
 from .operation import extern_primfunc
 
 from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index ae3ad7ca892a..1779f6efc595 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -19,7 +19,7 @@
 # pylint: disable=invalid-name
 from numbers import Integral as _Integral
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import tvm._ffi
 import tvm.arith._ffi_api
@@ -571,12 +571,64 @@ def create_prim_func(
     ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None
 ) -> tvm.tir.PrimFunc:
     """Create a TensorIR PrimFunc from tensor expression
+    Parameters
+    ----------
+    ops : List[Tensor]
+        The source expression.
+    Example
+    -------
+    We define a matmul kernel using the following code:
+    .. code-block:: python
+        import tvm
+        from tvm import te
+        from tvm.te import create_prim_func
+        import tvm.script
+        A = te.placeholder((128, 128), name="A")
+        B = te.placeholder((128, 128), name="B")
+        k = te.reduce_axis((0, 128), "k")
+        C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
+        func = create_prim_func([A, B, C])
+        print(func.script())
+    If we want to use TensorIR schedule to do transformations on such kernel,
+    we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc.
+    The generated function looks like:
+    .. code-block:: python
+        @T.prim_func
+        def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+            A = T.match_buffer(a, (128, 128))
+            B = T.match_buffer(b, (128, 128))
+            C = T.match_buffer(c, (128, 128))
+            for i, j, k in T.grid(128, 128, 128):
+                with T.block():
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] += A[vi, vk] * B[vj, vk]
+    Returns
+    -------
+    func : tir.PrimFunc
+        The created function.
+    """
+    if not isinstance(ops, (list, tuple, Array)):
+        ops = [ops]
+    return _ffi_api.CreatePrimFunc(ops, index_dtype_override)
+
+
+def create_relax_prim_func(
+    ops: List[_tensor.Tensor],
+    tir_var_list: Optional[List[tvm.tir.Var]] = None,
+    index_dtype_override: Optional[str] = None,
+) -> tvm.tir.PrimFunc:
+    """Create a TensorIR PrimFunc from tensor expression
     Parameters
     ----------
     ops : List[Tensor]
         The source expression.
+ tir_var_list: List[Var] + TIR variables to add as parameters to generated PrimFunc + Example ------- We define a matmul kernel using following code: @@ -621,4 +673,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops, index_dtype_override) + return _ffi_api.CreateRelaxPrimFunc(ops, tir_var_list, index_dtype_override) diff --git a/src/ir/function.cc b/src/ir/function.cc index 69752f529a3c..6a7ccc7cf27b 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -46,4 +46,18 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> BaseFunc { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + } // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index d9b139753455..2de06fe5d6f2 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -533,6 +533,155 @@ TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") return IsBaseOf(base, derived); }); +//-------------------------- +// DeriveStructInfo +//-------------------------- + +// NOTE: we are reusing StructInfoBaseChecker here to populate a mapping +// from the expressions in arg(rhs) to var in param. +class CallRetStructInfoDeriver : public StructInfoBaseChecker { + public: + explicit CallRetStructInfoDeriver(arith::Analyzer* ana) : StructInfoBaseChecker(ana) {} + + // No short cut, so we can recursively populate all pairs. + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + StructInfo Derive(const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + // opaque derivation + if (finfo->IsOpaque()) { + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + return finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + return finfo->ret; + } + } + + // Normal function signature derivation. + auto params = finfo->params.value(); + if (params.size() != call->args.size()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "number of arguments and parameters mismatch:" + << " expected " << params.size() << ", given " << call->args.size()); + } + // Visit each param arg pair, check and populate the var map + for (size_t i = 0; i < params.size(); ++i) { + auto arg_sinfo = GetStructInfo(call->args[i]); + BaseCheckResult res = this->VisitStructInfo(params[i], arg_sinfo); + // Report error if we find L1 level failure + // L2 level is best effort so we don't report. + // The behavior of L2 can be customized later. + if (res == BaseCheckResult::kFailL0 || res == BaseCheckResult::kFailL1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Argument " << i << " type mismatch:" + << " expected " << params[i] << ", given " << arg_sinfo); + } + } + // map the ret using the populated var map. + return EraseToWellDefined(finfo->ret, shape_var_map_, var_map_); + } + + protected: + // Whether to populate map in params. 
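+  // When population is enabled, visiting each param/arg pair records the
+  // symbolic vars bound by the callee's parameters. For example, matching a
+  // param with struct info Tensor[(n, 4)] against an argument with struct info
+  // Tensor[(3, 4)] records n -> 3 in shape_var_map_, so a return struct info
+  // of Tensor[(n + 1, 4)] is later erased to Tensor[(3 + 1, 4)].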
+ bool populate_mapping_{true}; + // for simplicity, we make these fields public so the user can access them. + Map shape_var_map_; + Map var_map_; + + using StructInfoBaseChecker::ShapeMatchCheck; + + // Match shape values in between param(lhs) and arg(rhs) + BaseCheckResult PrimValueMatchCheck(const PrimExpr& param, const PrimExpr& arg) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + + if (auto* ptr = param.as()) { + auto var = GetRef(ptr); + auto it = shape_var_map_.find(var); + // not populated + if (it == shape_var_map_.end()) { + shape_var_map_.Set(var, arg); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + PrimExpr mapped_value = (*it).second; + if (analyzer_->CanProveEqual(mapped_value, arg)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } else { + // Best effort + // Do not attempt to do prove when param contains a symbolic expr. + // such expression might depends on a later defined var in params created by dyn fusion. + // example: f(a: Tensor[(n+1)], s: Shape[(n,)]), the (n+1) case here. + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + } + + BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::ShapeMatchCheck(lhs, rhs); + } + + if (auto* ptr = lhs.as()) { + auto var = GetRef(ptr); + auto it = var_map_.find(var); + // not populated + if (it == var_map_.end()) { + var_map_.Set(var, rhs); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + Expr mapped_value = (*it).second; + if (CanProveShapeEqual(mapped_value, rhs, analyzer_)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } + auto lhs_shape = lhs.as(); + auto rhs_shape = rhs.as(); + ICHECK(lhs_shape) << "lhs must have a shape"; + if (!rhs_shape) return BaseCheckResult::kFailL2; + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } + + BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) final { + // Set populate mapping to false + // so we do not pick up symbolic vars in params with function type. + // + // @R.function + // def f(g: R.Func([R.Tensor[(n,)]], R.Tensor[(n+1,)]), + // x: R.Tensor[(m,)]) -> R.Tensor[(m,)]: + // ... + // + // For example, in the above function f, we should avoid + // pick up n in g's signature. + bool populate_mapping = false; + std::swap(populate_mapping_, populate_mapping); + auto ret = StructInfoBaseChecker::FuncParamsCheck(lhs, rhs); + std::swap(populate_mapping_, populate_mapping); + return ret; + } +}; + +StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return CallRetStructInfoDeriver(&inst).Derive(finfo, call, ctx); + } else { + return CallRetStructInfoDeriver(ana).Derive(finfo, call, ctx); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") + .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetStructInfo(finfo, call, ctx); + }); + //-------------------------- // UnifyToLCA //-------------------------- diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc new file mode 100644 index 000000000000..6a2d7ea5c584 --- /dev/null +++ b/src/relax/ir/block_builder.cc @@ -0,0 +1,969 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/block_builder.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Block builder have three categories of logics that are interdependent with each other. +// +// The logics are somewhat interdependent with each other. +// To help us implement a block builder in two parts: +// +// - BlockBuilderImpl: implements ctx and scope management, with no normalization. +// - BlockBuilderImplWithNormalize: subclasses BlockBuilderImpl and implements normalization. +// +// The final blockbuilder create will be backed by BlockBuilderWithNormalize + +namespace tvm { +namespace relax { + +//--------------------------------------- +// ctx and scope management. +//--------------------------------------- +class BlockBuilderImpl : public BlockBuilderNode { + public: + explicit BlockBuilderImpl(IRModule context_mod) : context_mod_(std::move(context_mod)) {} + + ~BlockBuilderImpl() { + if (!block_stack_.empty()) { + LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; + } + } + + //------------------------------- + // Global Context management + //------------------------------- + NameTable* name_table() final { return name_table_.get(); } + + IRModule GetContextIRModule() const final { return context_mod_; } + + GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final { + LazyInitCtxFuncDedupMap(); + auto it = ctx_func_dedup_map_->find(func); + if (it == ctx_func_dedup_map_->end()) { + context_mod_.CopyOnWrite(); + + String func_name = name_table_->GetUniqueName(func_name_hint); + while (context_mod_->ContainGlobalVar(func_name)) { + func_name = name_table_->GetUniqueName(func_name_hint); + } + GlobalVar gvar = GlobalVar(func_name); + + StructInfo finfo; + if (func->struct_info_.defined()) { + finfo = GetStructInfo(func); + } else if (auto* prim_func = func.as()) { + // NOTE: use a slightly different struct info than checked type + // in PrimFunc so handle can turn into Tensor. + // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. 
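+        // For now the PrimFunc is registered as an opaque function whose
+        // return struct info is converted from its TIR return type.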
+ finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + } else { + finfo = StructInfoFromType(func->checked_type()); + } + UpdateStructInfo(gvar, finfo); + + context_mod_->Add(gvar, func); + + ctx_func_dedup_map_->emplace(func, gvar); + return gvar; + } else { + return it->second; + } + } + + void UpdateFunction(const GlobalVar& gv, BaseFunc function) final { + context_mod_.CopyOnWrite(); + + // invalidate old dedup map + if (ctx_func_dedup_map_ != nullptr) { + auto it = context_mod_->functions.find(gv); + if (it != context_mod_->functions.end()) { + BaseFunc old_func = (*it).second; + auto ptr = ctx_func_dedup_map_->find(old_func); + ICHECK(ptr != ctx_func_dedup_map_->end()); + ctx_func_dedup_map_->erase(ptr); + } + } + + context_mod_->Update(gv, function); + + // add new dedup map item. + if (ctx_func_dedup_map_ != nullptr) { + ctx_func_dedup_map_->emplace(function, gv); + } + } + + void ReportFatal(const Diagnostic& diagnostic) final { + // TODO(relax-team): Print more context information by looking + // into the diagnostic->loc and surrounding IRModule. + // We do not materialzie DiagnosticContext to avoid double referencing to + // the change IRModule in COW. Additionally, we need to be able to + // continue use the builder after an error is thrown to avoid state building up. + // in an interactive environment. + LOG(FATAL) << diagnostic->message; + } + + //------------------------------- + // Scope management + //------------------------------- + Optional LookupBinding(const Var& var) final { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) return NullOpt; + return it->second; + } + + void BeginDataflowBlock() final { block_stack_.emplace_back(BlockFrame{{}, true}); } + + void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } + + void BeginScope(Optional> params) final { + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + // + // TODO(relax-team): Add support for relax Var in struct info annotations. + Map shape_var_map; + for (const Var& var : params.value_or(Array())) { + const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; + } + shape_var_map.Set(shape_var, shape_expr); + } + } + scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); + } + + void EndScope() final { scope_stack_.pop_back(); } + + BindingBlock EndBlock() final { + BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlock ret = cur_frame->is_dataflow ? 
DataflowBlock(cur_frame->bindings) + : BindingBlock(cur_frame->bindings); + block_stack_.pop_back(); + return ret; + } + + bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } + + Var Emit(Expr expr, String name_hint) final { + return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); + } + + Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final { + value = this->Normalize(value); + + CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) + << "It is impossible to match cast any value into the target struct_info. " + "But got value struct info: " + << GetStructInfo(value) << ", given struct info: " << struct_info; + + // NOTE: do match cast checking later in a pass. + BlockFrame* cur_frame = CurrentBlockFrame(); + Var var = CreateVar(cur_frame->is_dataflow, name_hint); + UpdateStructInfo(var, struct_info); + + MatchCast match_cast(var, value, struct_info); + cur_frame->bindings.push_back(match_cast); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + return var; + } + + Var EmitOutput(Expr output, String name_hint) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + + return Emit(output, false, name_hint); + } + + void EmitNormalized(Binding binding) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + if (const auto* var_binding = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!var_binding->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(var_binding->var->struct_info_.defined()); + ICHECK(var_binding->value->struct_info_.defined()); + cur_frame->bindings.push_back(binding); + binding_table_[var_binding->var->vid] = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!match_cast->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(match_cast->var->struct_info_.defined()); + ICHECK(match_cast->value->struct_info_.defined()); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + cur_frame->bindings.push_back(binding); + } else { + LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); + } + } + + arith::Analyzer* GetAnalyzer() final { return &analyzer_; } + + protected: + /*! + * \brief A representation of a block frame. + * + * A block frame is a record containing the bindings needed + * to build a binding block, and a boolean to indicate if the + * block being built is a DataflowBlock or not. + */ + struct BlockFrame { + /*! + * \brief List of bindings + */ + Array bindings; + /*! \brief Whether current block is dataflow block. */ + bool is_dataflow; + /*! + * \brief Binding map used by normalizer. + * + * \note The normalizer only caches reuse in the current block scope + * and will not cache bindings from parent scope. + */ + std::unordered_map normalize_binding_map; + }; + /*! + * \brief A representation of a scope frame. + * + * A scope frame records tracks the context of current scope. + */ + struct ScopeFrame { + // NOTE: for simplicity, only tracks symbolic var for now + // the scope is only used for erasure, so less information means + // more conservative analysis. + // Consider impl alternative: merge with block frame if we have more frame kinds. 
+ // + // TODO(relax-team) tracks the var defined also through match-cast. + /*! \brief set of defined symbolic vars, value as themself. */ + Map shape_var_map; + }; + + /*! \brief A stack to store block frames. */ + std::vector block_stack_; + + /*! \brief A stack to store scope frames. */ + std::vector scope_stack_; + + /*! \brief A binding table that maps var to value. */ + std::unordered_map binding_table_; + + /*! \brief A name table to get unique names for IR construction. */ + std::unique_ptr name_table_ = std::make_unique(); + + /*! \brief The IRModule being built by the BlockBuilder. */ + IRModule context_mod_; + + /*! \brief Internal analzyer */ + arith::Analyzer analyzer_; + + /*! + * \return The current frame. + * \note Never hold the value of current frame between Normalize + * or other scope calls this value can change if the block stack get updated, + * then the block frame is no longer valid. + */ + BlockFrame* CurrentBlockFrame() { + ICHECK(!block_stack_.empty()) << "no block is being built"; + return &block_stack_.back(); + } + + /*! + * \return The current scope frame. + * \note only use this value + */ + ScopeFrame* CurrentScopeFrame() { + ICHECK(!scope_stack_.empty()) << "no scope is being opened"; + return &scope_stack_.back(); + } + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \note This Emit function normalizes the \p expr, + * and performs shape/type deductions by calling Normalize. + * \return The new variable that \p expr is bound to. + */ + Var Emit(Expr expr, bool is_dataflow, String name_hint) { + expr = this->Normalize(expr); + + Var var = CreateVar(is_dataflow, name_hint); + + // set the values + UpdateStructInfo(var, Downcast(expr->struct_info_.value())); + + CurrentBlockFrame()->bindings.push_back(VarBinding(var, expr)); + + // update the binding table + binding_table_[var->vid] = expr; + + return var; + } + + /*! + * \brief Create var for bindings + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \return The created var. + */ + Var CreateVar(bool is_dataflow, String name_hint) { + if (name_hint.empty()) { + name_hint = is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/NullOpt) + : Var(vid, /*struct_info_annotation=*/NullOpt); + } + + private: + /*! + * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs + * in context_mod to their GlobalVar to avoid generating duplicated functions. + */ + std::unique_ptr> + ctx_func_dedup_map_ = nullptr; + + /*! + * \brief lazily initialize function dedeup map. + */ + void LazyInitCtxFuncDedupMap() { + if (ctx_func_dedup_map_ != nullptr) return; + ctx_func_dedup_map_ = std::make_unique< + std::unordered_map>(); + for (const auto& kv : context_mod_->functions) { + const GlobalVar gv = kv.first; + const BaseFunc func = kv.second; + ctx_func_dedup_map_->emplace(func, gv); + } + } + + // Collect all the variables that a parameter var can define. 
+ // The collector is used to making sure that we record the + // shape vars as defined when calling BeginScope(params) + class StructInfoVarCollector : public StructInfoVisitor { + public: + static Map Collect(const StructInfo& struct_info) { + StructInfoVarCollector collector; + collector(struct_info); + return collector.shape_var_map_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) final { + if (const auto* shape_expr = op->shape.as()) { + for (const PrimExpr& s : shape_expr->values) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + } + + void VisitStructInfo_(const ShapeStructInfoNode* op) final { + for (const PrimExpr& s : op->values.value_or(Array())) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + + private: + Map shape_var_map_; + }; +}; + +//--------------------------------------- +// Normalization +//--------------------------------------- +#define RELAX_EXPR_NORMALIZER_LEAF(OP) \ + Expr VisitExpr_(const OP* op) final { return GetRef(op); } + +// TODO(relax-team): Check normalize logic after struct info. + +// Normalizer on struct info: +// +// We take benefit of the following invariants(that are checked in constructor): +// - If an expr appears in StructInfo, then it is already normalized. +// As a result, we do not need to peek into StructInfo in Normalization. +// - Constant, ShapeExpr, already have their StructInfo populated in constructing time. +class Normalizer : public BlockBuilderImpl, private ExprFunctor { + public: + explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} + + Expr Normalize(const Expr& expr) final { + Expr normalized = this->VisitExpr(expr); + // Invariant: + // After Normalize: an Expr always have + // struct_info (with the exception of Op). + if (!normalized->IsInstance()) { + ICHECK(normalized->struct_info_.defined()) + << "The struct_info_ of an Expr except OpNode after " + "normalization must not be nullptr. However, this Expr does not have struct_info_: " + << normalized; + } + + return normalized; + } + + /*! + * \brief Normalize Argument values to call and other IR sub-fields. + * \param arg The argument. + * \return The normalized value. + * + * \note This function create a new binding for non-leaf expressions except for tuple. + */ + Expr NormalizeArgument(const Expr& arg) final { + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. 
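+    // An inline PrimFunc argument gets an opaque FuncStructInfo attached
+    // (see NormalizePrimFunc) instead of going through the normal argument path.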
+ if (auto* prim_func = arg.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } + + if (!block_stack_.empty()) { + // cache lookup + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(arg); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + // skip visit expr's cache, normalize arg + Expr post = ExprFunctor::VisitExpr(arg); + + if (!IsLeafOrTuple(arg)) { + ICHECK(!block_stack_.empty()) << "Cannot normalize non-leaf without a scope"; + Var var = this->Emit(post, ""); + // NOTE: current frame addr can change due to underlying vector + // re-allocation, redo lookup + CurrentBlockFrame()->normalize_binding_map[arg] = var; + return var; + } else { + return post; + } + } + + RELAX_EXPR_NORMALIZER_LEAF(ExternFuncNode); + RELAX_EXPR_NORMALIZER_LEAF(GlobalVarNode); + RELAX_EXPR_NORMALIZER_LEAF(OpNode); + RELAX_EXPR_NORMALIZER_LEAF(ConstantNode); + RELAX_EXPR_NORMALIZER_LEAF(ShapeExprNode); + RELAX_EXPR_NORMALIZER_LEAF(PrimValueNode); + RELAX_EXPR_NORMALIZER_LEAF(StringImmNode); + RELAX_EXPR_NORMALIZER_LEAF(DataTypeImmNode); + + template + Expr VisitVar_(const typename T::ContainerType* var) { + // Parameters and free-vars must be present with struct info + // Other vars must have already been normalized through binding + ICHECK(var->struct_info_.defined()) + << "Var " << var->name_hint() << " does not have struct info."; + return GetRef(var); + } + + Expr VisitExpr_(const VarNode* var) final { return VisitVar_(var); } + + Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_(var); } + + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + Expr NormalizePrimFunc(tir::PrimFunc prim_func) { + if (!prim_func->struct_info_.defined()) { + auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + UpdateStructInfo(prim_func, finfo); + } + return prim_func; + } + + Expr VisitExpr(const Expr& expr) final { + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + if (auto* prim_func = expr.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } + + // lookup normalize map + if (!block_stack_.empty()) { + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(expr); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + return ExprFunctor::VisitExpr(expr); + } + + Expr VisitExpr_(const TupleNode* op) final { + bool unchanged = true; + Array new_fields; + + for (const Expr& field : op->fields) { + Expr new_field = this->NormalizeArgument(field); + new_fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields, op->span); + // Update tuple fields. 
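+    // When the tuple does not yet carry struct info, derive a TupleStructInfo
+    // from the struct info of its (now normalized) fields.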
+ if (!tuple->struct_info_.defined()) { + Array tuple_sinfo; + for (Expr field : tuple->fields) { + tuple_sinfo.push_back(GetStructInfo(field)); + } + UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo, op->span)); + } + return tuple; + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr new_body = this->VisitWithNewScope(op->body, op->params); + + if (new_body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, new_body, op->ret_struct_info, op->attrs); + } + } + + Expr VisitExpr_(const CallNode* op) final { + Expr new_op = this->NormalizeArgument(op->op); + bool unchanged = new_op.same_as(op->op); + + Array new_args; + + for (Expr arg : op->args) { + Expr new_arg = this->NormalizeArgument(arg); + new_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + Call call; + if (unchanged) { + call = GetRef(op); + } else { + call = Call(new_op, new_args, op->attrs, op->sinfo_args); + } + + if (!call->struct_info_.defined()) { + auto inferred_sinfo = InferStructInfo(call); + UpdateStructInfo(call, inferred_sinfo); + } + + return call; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool unchanged = true; + Array new_blocks; + for (BindingBlock block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + new_blocks.push_back(new_block); + unchanged &= new_block.same_as(block); + } + + this->BeginBindingBlock(); + // the body may not be a leaf expression, so check for that + Expr new_body = this->NormalizeArgument(op->body); + unchanged &= new_body.same_as(op->body); + BindingBlock prologue = this->EndBlock(); + + if (!prologue->bindings.empty()) { + new_blocks.push_back(prologue); + unchanged = false; + } + + // Combine nearby blocks if possible + Array normalized_blocks = NormalizeBlocks(new_blocks); + unchanged &= normalized_blocks.same_as(new_blocks); + + SeqExpr seq_expr; + if (unchanged) { + seq_expr = GetRef(op); + } else { + seq_expr = SeqExpr(normalized_blocks, new_body, op->span); + } + + // only do shape/type inference if the SeqExpr does not have shape/type + if (!seq_expr->struct_info_.defined()) { + UpdateStructInfo(seq_expr, EraseToWellDefinedInScope(GetStructInfo(seq_expr->body))); + } + return seq_expr; + } + + Expr VisitExpr_(const IfNode* op) final { + Expr new_cond = this->NormalizeArgument(op->cond); + Expr new_true = this->VisitWithNewScope(op->true_branch); + Expr new_false = this->VisitWithNewScope(op->false_branch); + + If if_node; + if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) && + new_false.same_as(op->false_branch)) { + if_node = GetRef(op); + } else { + if_node = If(new_cond, new_true, new_false, op->span); + } + if (!if_node->struct_info_.defined()) { + auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true)); + auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false)); + UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info)); + } + return if_node; + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_tuple = this->NormalizeArgument(op->tuple); + + TupleGetItem node = new_tuple.same_as(op->tuple) ? 
GetRef(op) + : TupleGetItem(new_tuple, op->index); + + if (!node->struct_info_.defined()) { + auto opt = MatchStructInfo(node->tuple); + ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo."; + UpdateStructInfo(node, opt.value()->fields[node->index]); + } + + return node; + } + + Binding VisitBinding(const Binding& binding) { + if (auto* var_binding = binding.as()) { + return this->VisitVarBinding(GetRef(var_binding)); + } else { + auto* match_cast = binding.as(); + ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); + return this->VisitMatchCast(GetRef(match_cast)); + } + } + + VarBinding VisitVarBinding(VarBinding binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = VarBinding(binding->var, new_value, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + return binding; + } + + MatchCast VisitMatchCast(MatchCast binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = MatchCast(binding->var, new_value, binding->struct_info, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, binding->struct_info); + } + return binding; + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + this->BeginDataflowBlock(); + } else { + this->BeginBindingBlock(); + } + + bool unchanged = true; + for (const Binding& binding : block->bindings) { + Binding new_binding = this->VisitBinding(binding); + unchanged &= new_binding.same_as(binding); + + this->EmitNormalized(new_binding); + } + BindingBlock new_block = this->EndBlock(); + unchanged &= new_block->bindings.size() == block->bindings.size(); + if (unchanged) { + return block; + } + return new_block; + } + + private: + // Helper function to infer the type of a Call. + StructInfo InferStructInfo(const Call& call) { + if (auto* op_ptr = call->op.as()) { + // Case 1: the op field is a primitive op, look up FInferStructInfo attribute + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + return op_map_infer_struct_info_[op](call, GetRef(this)); + } else { + // derive using function parameters + ICHECK(call->op->struct_info_.defined()); + auto opt = MatchStructInfo(call->op); + ICHECK(opt) << "Call->op must contains a function struct info"; + FuncStructInfo finfo = opt.value(); + return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); + } + } + + // erase to well defined within current scope. + StructInfo EraseToWellDefinedInScope(StructInfo info) { + if (scope_stack_.empty()) { + return EraseToWellDefined(info); + } + auto* curr_scope = CurrentScopeFrame(); + auto f_shape_var_map = [curr_scope](tir::Var var) -> Optional { + auto it = curr_scope->shape_var_map.find(var); + if (it != curr_scope->shape_var_map.end()) return (*it).second; + return NullOpt; + }; + return EraseToWellDefined(info, f_shape_var_map); + } + + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + // SeqExpr do not need to prepare for normalization. 
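+    // A SeqExpr body can be visited directly; anything else is normalized
+    // inside a fresh binding block and wrapped into a SeqExpr below, so every
+    // new scope (function body, if/else branch) stays a SeqExpr in normal form.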
+ if (expr.as()) { + this->BeginScope(params); + Expr ret = this->VisitExpr(expr); + this->EndScope(); + return ret; + } else { + this->BeginScope(params); + + this->BeginBindingBlock(); + Expr post = this->NormalizeArgument(expr); + BindingBlock prologue = this->EndBlock(); + // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. + // Don't wrap if it's already a seq and there are no bindings to add + if (post.as() && prologue->bindings.empty()) { + return post; + } + Array bindings; + if (!prologue->bindings.empty()) { + bindings.push_back(prologue); + } + + SeqExpr seq(bindings, post); + UpdateStructInfo(seq, EraseToWellDefinedInScope(GetStructInfo(seq->body))); + + this->EndScope(); + return seq; + } + } + + Array FlattenBlocks(const Array& blocks) { + // If there is a binding that is a seq expr, split the current block, + // add the nested blocks prior to the seq expr, and bind the seq expr body + // to the var + Array ret; + bool changed = false; + for (const BindingBlock& block : blocks) { + bool is_dataflow = block->IsInstance(); + Array current; + for (const Binding& binding : block->bindings) { + Expr value; + if (const auto* var_binding = binding.as()) { + value = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + value = match_cast->value; + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + // if we encounter a nested seq, we have to flatten it: + // 1. Append the binding block we've accumulated so far + // 2. Reset the current block + // 3. Append the inner blocks + // 4. Add a binding of the current var to the seq expr's body to the current block + // then continue + if (auto seq = value.as()) { + changed = true; + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + current = {}; + // We do not need to flatten recursively because the normalizer will have normalized + // and thus flattened the inner SeqExprs already + for (const BindingBlock& block : seq->blocks) { + if (is_dataflow && !block->IsInstance()) { + LOG(WARNING) << "Malformed AST: Seq expr nested inside a dataflow block contains a " + "non-dataflow block! " + << seq; + } + ret.push_back(block); + } + + if (const auto* var_binding = binding.as()) { + current.push_back(VarBinding(var_binding->var, seq->body)); + } else if (const auto* match_cast = binding.as()) { + current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + } else { + current.push_back(binding); + } + } + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + } + return changed ? ret : blocks; + } + + Array NormalizeBlocks(const Array& blocks) { + bool changed = false; + Array ret; + auto flattened = FlattenBlocks(blocks); + if (!flattened.same_as(blocks)) { + changed = true; + } + for (const BindingBlock& block : flattened) { + if (block->bindings.empty()) { + // Case 1. Skip empty blocks + changed = true; + } else if (!ret.empty() && ret.back()->type_index() == block->type_index()) { + // Case 2. Merge with previous block if possible + BindingBlock merged; + // NOTE: should check DataflowBlockNode first. 
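+        // DataflowBlockNode is a subtype of BindingBlockNode, so testing the
+        // base class first would also match dataflow blocks.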
+ if (const auto* dataflow_block = ret.back().as()) { + auto n = make_object(*dataflow_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = DataflowBlock(n); + } else if (const auto* binding_block = ret.back().as()) { + auto n = make_object(*binding_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = BindingBlock(n); + } else { + LOG(FATAL) << "Unknown block type: " << ret.back()->GetTypeKey(); + } + ret.pop_back(); + ret.push_back(merged); + changed = true; + } else { + // Case 3. Add to the result + ret.push_back(block); + } + } + return changed ? ret : blocks; + } + + /*! \brief Operator struct info inference map. */ + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); +}; + +BlockBuilder BlockBuilder::Create(Optional mod) { + ObjectPtr n = make_object(mod.value_or(IRModule())); + return BlockBuilder(n); +} + +//--------------------------------------- +// User facing function registration. +//--------------------------------------- +TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { + return BlockBuilder::Create(mod); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") + .set_body_method(&BlockBuilderNode::BeginDataflowBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") + .set_body_method(&BlockBuilderNode::BeginBindingBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock") + .set_body_method(&BlockBuilderNode::EndBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Expr expr) { + return builder->Emit(expr); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") + .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) { + return builder->EmitMatchCast(value, struct_info); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") + .set_body_typed([](BlockBuilder builder, const Expr& output) { + return builder->EmitOutput(output); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") + .set_body_typed([](BlockBuilder builder, Binding binding) { + return builder->EmitNormalized(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") + .set_body_typed([](BlockBuilder builder, String name_hint) { + return builder->name_table()->GetUniqueName(name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") + .set_body_method(&BlockBuilderNode::AddFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") + .set_body_method(&BlockBuilderNode::UpdateFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") + .set_body_method(&BlockBuilderNode::GetContextIRModule); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") + .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") + .set_body_method(&BlockBuilderNode::LookupBinding); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope") + .set_body_method(&BlockBuilderNode::EndScope); +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc new file mode 100644 index 000000000000..bfb5896c9988 --- /dev/null +++ 
b/src/relax/ir/emit_te.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.cc + */ +#include "./emit_te.h" + +#include +#include + +namespace tvm { +namespace relax { + +// RXPlaceholderOpNode +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rxplaceholder(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode); + +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { + auto n = make_object(); + n->name = name; + n->value = value; + + // If the value is a constant, it might come as an argument of EmitTE and thus its shape and + // checked-type might not be properly set. In this case we set the shape and dtype of the returned + // TE tensor. + if (const auto* constant = value.as()) { + n->dtype = DataType(constant->data->dtype); + + int ndim = constant->data->ndim; + ShapeTuple shape_tuple = constant->data.Shape(); + Array shape; + shape.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); + } + n->shape = std::move(shape); + return te::PlaceholderOp(n).output(0); + } + ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; + auto* tensor_sinfo = GetStructInfoAs(value); + ICHECK(tensor_sinfo) << "Value must be a tensor"; + auto* shape_expr = tensor_sinfo->shape.as(); + CHECK(shape_expr) + << "ValueError: Expression does not have an known symbolic shape, please consider use " + "match_cast " + << "to constrain the shape before passing into te_tensor"; + n->shape = shape_expr->values.Map( + [&tir_var_map](const PrimExpr& e) { return tir::Substitute(e, tir_var_map); }); + n->dtype = tensor_sinfo->dtype; + return te::PlaceholderOp(n).output(0); +} + +TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h new file mode 100644 index 000000000000..46207479c7ef --- /dev/null +++ b/src/relax/ir/emit_te.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.h + * \brief Tensor expression extension in Relax. + */ +#ifndef TVM_RELAX_IR_EMIT_TE_H_ +#define TVM_RELAX_IR_EMIT_TE_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A placeholder op that represents a relax expression. + */ +class RXPlaceholderOpNode : public te::PlaceholderOpNode { + public: + /*! \brief The relax expression. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("attrs", &attrs); + v->Visit("value", &value); + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "RXPlaceholderOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); +}; + +/*! + * \brief Create a TE tensor from relax expression, with TIR variables in the + * tensor shape substituted by the given mapping. + * \param value The relax expression, which is required to have TensorStructInfo. + * \param tir_var_map The mapping to substitute the TIR variables appeared in the + * shape of the input Expr. + * \param name The name of the created tensor. + */ +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_IR_EMIT_TE_H_ diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 048de7950f97..4c4b68f3d200 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -542,5 +542,249 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } +// ================== +// ExprMutator + +Expr ExprMutator::VisitExpr(const Expr& expr) { + return builder_->Normalize(ExprFunctor::VisitExpr(expr)); +} + +// Visit the use-site of a defined Var +Expr ExprMutator::VisitExpr_(const VarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. 
+ return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : op->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(op->body, params); + + // FuncStructInfo does not depend on Expr + if (all_params_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprMutator); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { + Var new_var = this->VisitVarDef(binding->var); + + // fast path: reemit binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + return; + } + + Var temp = WithStructInfo(new_var, GetStructInfo(new_value)); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } + + builder_->EmitNormalized(VarBinding(new_var, new_value)); +} + +void ExprMutator::VisitBinding_(const MatchCastNode* binding) { + Var new_var = this->VisitVarDef(binding->var); + Expr new_value = this->VisitExpr(binding->value); + + // re-emit old binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + 
builder_->EmitNormalized(GetRef(binding)); + } else { + new_value = builder_->NormalizeArgument(new_value); + builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); + } +} + +BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return DataflowVar(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +Var ExprMutator::VisitVarDef_(const VarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return Var(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +void ExprMutator::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; +} + +Var ExprMutator::VisitVarDef(const Var& var) { + Var ret; + if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + return ret; +} + +Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + builder_->EndScope(); + return ret; +} + +Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } + +Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { + ICHECK(struct_info.defined()); + + // TODO(relax-team) add StructInfoEqual check + if (var->struct_info_.defined()) { + // use same-as as a quick path + if (var->struct_info_.same_as(struct_info) || + StructuralEqual()(var->struct_info_, struct_info)) { + return var; + } else { + Var new_var = var.as() ? 
DataflowVar(var->vid, struct_info, var->span) + : Var(var->vid, struct_info, var->span); + return new_var; + } + } else { + UpdateStructInfo(var, struct_info); + return var; + } +} + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc new file mode 100644 index 000000000000..7e86235aa61e --- /dev/null +++ b/src/relax/ir/py_expr_functor.cc @@ -0,0 +1,649 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/py_expr_functor.cc + * \brief The backbone of PyExprVisitor/PyExprMutator. + */ +#include + +namespace tvm { +namespace relax { + +/*! + * \brief The abstract interface of ExprVisitor. + */ +class PyExprVisitorNode : public Object, public ExprVisitor { + private: + using TSelf = PyExprVisitorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. 
*/ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + void VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + f_visit_expr(expr); + } else { + // Need to init the overwrite VTable + static FType vtable = InitVTable(); + vtable(expr, this); + } + } + + void VisitBinding(const Binding& binding) + PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); + + void VisitBinding_(const VarBindingNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + ExprVisitor::VisitBinding_(binding)); + void VisitBinding_(const MatchCastNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_cast_, + ExprVisitor::VisitBinding_(binding)); + + void VisitBindingBlock(const BindingBlock& block) + PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); + + void VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprVisitor::VisitBindingBlock_(block)); + void VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprVisitor::VisitBindingBlock_(block)); + + void VisitVarDef(const Var& var) + PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); + void VisitVarDef_(const VarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + void VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprVisitor::VisitVarDef_(var)); + + void VisitSpan(const Span& span) + PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + + 
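+
+  // Note: every f_visit_* packed function above is optional. A non-null field overrides the
+  // corresponding C++ visitor method from the Python side; a null field falls back to the
+  // default ExprVisitor implementation (see the PY_EXPR_VISITOR_DEFAULT and
+  // PY_EXPR_VISITOR_DISPATCH macros).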
private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_VISITOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_VISITOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_VISITOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprVisitorNode); + +/*! + * \brief Managed reference to PyExprVisitorNode. + * \sa PyExprVisitorNode + */ +class PyExprVisitor : public ObjectRef { + public: + /*! + * \brief Create a PyExprVisitor with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. 
+ * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyVisitor created. + */ + TVM_DLL static PyExprVisitor MakePyExprVisitor( + PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, + PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->f_visit_expr = f_visit_expr; + n->f_visit_binding = f_visit_binding; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_span = f_visit_span; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + return PyExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); +}; + +/*! + * \brief The abstract interface of ExprMutator. + */ +class PyExprMutatorNode : public Object, public ExprMutator { + private: + using TSelf = PyExprMutatorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! 
\brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. 
*/ + PackedFunc f_visit_span{nullptr}; + + Expr VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + return builder_->Normalize(f_visit_expr(expr)); + } else { + static FType vtable = InitVTable(); + return builder_->Normalize(vtable(expr, this)); + } + } + + void VisitBinding(const Binding& binding) { + if (f_visit_binding != nullptr) + f_visit_binding(binding); + else + ExprMutator::VisitBinding(binding); + } + + void VisitBinding_(const VarBindingNode* binding) { + if (f_visit_var_binding_ != nullptr) + f_visit_var_binding_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + void VisitBinding_(const MatchCastNode* binding) { + if (f_visit_match_cast_ != nullptr) + f_visit_match_cast_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) + PY_EXPR_MUTATOR_DEFAULT(block, f_visit_binding_block, ExprMutator::VisitBindingBlock(block), + BindingBlock); + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + + Var VisitVarDef(const Var& var) + PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); + Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprMutator::VisitVarDef_(var), Var); + + /*! + * \brief Dispatcher for post-order rewrite. + * \param expr The Expr to be rewritten. + * \return The Expr after post-order rewritten. + */ + Expr VisitExprPostOrder(const Expr& expr) { + static FType post_order_vtable = InitPostOrderVTable(); + return post_order_vtable(expr, this); + } + + using ExprMutator::builder_; + using ExprMutator::LookupBinding; + using ExprMutator::var_remap_; + using ExprMutator::VisitWithNewScope; + using ExprMutator::WithStructInfo; + + void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + + private: + // initialize the vtable. 
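+  // (Same override-or-fallback dispatch pattern as PyExprVisitorNode::InitVTable above.)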
+ static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_MUTATOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_MUTATOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } + + // initialize the vtable for post order visit. + static FType InitPostOrderVTable() { + FType post_order_vtable; + // Set dispatch + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ConstantNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(CallNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(SeqExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode); + return post_order_vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprMutatorNode); + +/*! + * \brief Managed reference to PyExprMutatorNode. + * \sa PyExprMutatorNode + */ +class PyExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyExprMutator with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. 
+ * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyExprMutator created. + */ + TVM_DLL static PyExprMutator MakePyExprMutator( + BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, + PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->builder_ = builder_; + n->f_visit_expr = f_visit_expr; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_binding = f_visit_binding; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_binding_block_ = f_visit_binding_block_; + 
n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + n->f_visit_span = f_visit_span; + return PyExprMutator(n); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); +}; + +TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->ExprVisitor::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->ExprVisitor::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { + visitor->ExprVisitor::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") + .set_body_typed([](PyExprVisitor visitor, const Span& span) { + visitor->ExprVisitor::VisitSpan(span); + }); + +TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + mutator->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->ExprMutator::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + return mutator->ExprMutator::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->ExprMutator::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->ExprMutator::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") + 
.set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExprPostOrder(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitWithNewScope(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->LookupBinding(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") + .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { + return mutator->WithStructInfo(var, sinfo); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id, Var var) { + return mutator->var_remap_[id] = var; + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc new file mode 100644 index 000000000000..a3fa4412180c --- /dev/null +++ b/src/relax/op/op.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// call_tir + +StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "sinfo_args should have exact 1 output struct info."); + } + return call->sinfo_args[0]; +} + +RELAY_REGISTER_OP("relax.call_tir") + .set_num_inputs(3) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") + .set_attr("FInferStructInfo", InferStructInfoCallTIR); + +Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, + Optional packed_ints) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. 
" + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_tir"); + Call call; + if (!packed_ints) { + // don't use additional optional argument + call = Call(op, {func, args}, {}, {out_sinfo}); + } else { + call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + } + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc new file mode 100644 index 000000000000..260f71e7bfb6 --- /dev/null +++ b/src/relax/op/op_common.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "op_common.h" + +#include + +namespace tvm { +namespace relax { + +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + if (static_cast(call->args.size()) != n_input) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " op should have " << n_input << " arguments"); + } + Array input_tensor_sinfo; + input_tensor_sinfo.reserve(n_input); + for (int i = 0; i < n_input; ++i) { + const auto* sinfo = GetStructInfoAs(call->args[i]); + if (sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " requires the input " << op->arguments[i]->name + << " to be Tensor. 
However, the given one is "
+                       << call->args[i]->struct_info_->GetTypeKey());
+    }
+    input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
+  }
+  return input_tensor_sinfo;
+}
+
+Optional<Array<PrimExpr>> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx,
+                                                    const Array<PrimExpr>& x1_shape,
+                                                    const Array<PrimExpr>& x2_shape) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  int x1_ndim = x1_shape.size();
+  int x2_ndim = x2_shape.size();
+  int max_ndim = std::max(x1_ndim, x2_ndim);
+
+  std::vector<PrimExpr> output_shape;
+  output_shape.reserve(max_ndim);
+
+  int i = 1;
+  for (; i <= std::min(x1_ndim, x2_ndim); ++i) {
+    const PrimExpr& dim0 = x1_shape[x1_ndim - i];
+    const PrimExpr& dim1 = x2_shape[x2_ndim - i];
+    const auto* int_dim0 = dim0.as<IntImmNode>();
+    const auto* int_dim1 = dim1.as<IntImmNode>();
+    if (int_dim0 != nullptr && int_dim0->value == 1) {
+      output_shape.push_back(dim1);
+    } else if (int_dim1 != nullptr && int_dim1->value == 1) {
+      output_shape.push_back(dim0);
+    } else if (analyzer->CanProveEqual(dim0, dim1)) {
+      output_shape.push_back(dim0);
+    } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
+      ctx->ReportFatal(Diagnostic::Error(call->span)
+                       << "In " << call->op << ", the first input shape at dim " << x1_ndim - i
+                       << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i
+                       << " is " << dim1 << ", which are not broadcastable.");
+    } else {
+      // Use simple fallback when shape mismatch.
+      return NullOpt;
+    }
+  }
+  auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape;
+  for (; i <= max_ndim; ++i) {
+    output_shape.push_back(longer_shape[max_ndim - i]);
+  }
+  return Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
+}
+
+std::vector<int> NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim,
+                               const Array<Integer>& axes) {
+  ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function.";
+  std::vector<bool> appeared_dims_set;
+  std::vector<int> axes_non_neg;
+  appeared_dims_set.resize(ndim, /*value=*/false);
+  axes_non_neg.reserve(axes.size());
+  for (const Integer& axis : axes) {
+    int _axis = axis->value;
+    if (_axis < -ndim || _axis >= ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call->span)
+                       << "In " << call->op << ", the input axis " << _axis
+                       << " is out of range. The input tensor has " << ndim
+                       << " dimensions, so axis should be in range [" << -ndim << ", " << ndim
+                       << ").");
+    } else if (_axis < 0) {
+      _axis = ndim + _axis;
+    }
+
+    if (appeared_dims_set[_axis]) {
+      ctx->ReportFatal(Diagnostic::Error(call->span)
+                       << "In " << call->op
+                       << ", the input axes are required to be non-repetitive. However, there are "
+                          "multiple given axes referring to axis "
+                       << _axis);
+    }
+    appeared_dims_set[_axis] = true;
+    axes_non_neg.push_back(_axis);
+  }
+  return axes_non_neg;
+}
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
new file mode 100644
index 000000000000..8e362bb4d55c
--- /dev/null
+++ b/src/relax/op/op_common.h
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_common.h + * \brief A set of utilities and common functionality + * for Relax ops. + */ +#ifndef TVM_RELAX_OP_OP_COMMON_H_ +#define TVM_RELAX_OP_OP_COMMON_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/************ Op input struct info getter ************/ + +/*! + * \brief Get the tensor struct info of the operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of each input. + * \note This function require every input to be Tensor. The number of call arguments is required + * to match the number of inputs of the op being called. + */ +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); + +/*! + * \brief Get the tensor struct info of the unary operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of the unary operator input. + * \throw Throw exception if the number of input is not one, or the struct info of the input is not + * a tensor struct info. + */ +inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + return GetInputTensorStructInfo(call, ctx)[0]; +} + +/************ Op registration macro ************/ + +/*! + * \brief Quick helper macro to register the operator to registry + * \param OpRegName The name of operator to register. The name passed in will + * be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_UNARY_OP(OpRegName) \ + TVM_REGISTER_OP("relax." OpRegName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input tensor.") + +/*! + * \brief Quick helper macro to expose a make-function to construct the operator. + * \param OpName The name of the operator as well as the make-function name, which will + * be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * \param OpRegName The identifier of the operator in the registry. + */ +#define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ + Expr OpName(Expr x) { \ + static const Op& op = Op::Get("relax." OpRegName); \ + return Call(op, {std::move(x)}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + +/************ Utilities ************/ + +/*! + * \brief Infer the struct info for unary elementwise ops. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param f_compute_out_dtype The function to compute the output dtype, with + * signature DataType f_compute_out_dtype(const TensorStructInfo& input_sinfo). + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \tparam Ftype the type of f_compute_out_dtype + * \return The inferred struct info. 
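+ *
+ * Illustrative usage sketch (not part of the public API): the lambda below simply keeps the
+ * input dtype, mirroring InferStructInfoUnaryArith defined underneath; `call` and `ctx` are the
+ * arguments already passed to the surrounding FInferStructInfo function.
+ *
+ * \code
+ *   StructInfo out = InferStructInfoUnary<false>(
+ *       call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; });
+ * \endcode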
+ */ +template +inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, + FType f_compute_out_dtype) { + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call->span) + << call->op + << " requires the input tensor to have float dtype. However, the given input dtype is " + << input_sinfo->dtype); + } + auto output_sinfo = make_object(*input_sinfo.get()); + output_sinfo->dtype = f_compute_out_dtype(input_sinfo); + return TensorStructInfo(output_sinfo); +} + +/*! + * \brief Infer the struct info for unary arithmetic elementwise ops. It's also + * used in some NN operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \return The inferred struct info. + */ +template +StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); +} + +/************ Utilities ************/ + +/*! + * \brief Infer the output datatype for binary arithmetic operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_sinfo The struct info of the first operand + * \param x2_sinfo The struct info of the second operand + * \return The inferred output dtype. + * \throw Throw exception if the dtype of two input TensorStructInfo don’t match + */ +inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, + const TensorStructInfo& x1_sinfo, + const TensorStructInfo& x2_sinfo) { + if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { + return DataType::Void(); + } else if (x1_sinfo->dtype != x2_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype + << " must be equal for binary operators"); + } + return x1_sinfo->dtype; +} + +/*! + * \brief Infer the output shape for binary broadcast operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_shape The shape of the first operand. + * \param x2_shape The shape of the second operand. + * \return The inferred output shape after broadcasting. Or `NullOpt` if the output shape cannot be + * determined due to symbolic broadcast. + */ +Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, + const Array& x1_shape, + const Array& x2_shape); + +/*! + * \brief Convert all axes to non-negative indices, and meanwhile check if the given array of axes + * are all in range and non-repetitive with regards to the given ndim. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param ndim The ndim constraint, which is required to be known already. + * \param axes The axis indices to be checked + * \return The input axes in non-negative indexing. + * \throw Throw exception if there exists out-of-range axis index or repetitive indices. + */ +std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, + const Array& axes); + +/*! + * \brief Convert the given axis to non-negative index. Meanwhile check if the axis is in range + * with regards to the given ndim. + * \param call The context Call to the operator. 
+ * \param ctx The error reporting context. + * \param ndim The ndim constraint. + * \param axis The axis index to be checked + * \return The input axis in non-negative indexing. + * \throw Throw exception the given axis is out-of-range. + */ +inline int NormalizeAxis(const Call& call, const BlockBuilder& ctx, int ndim, int axis) { + return NormalizeAxes(call, ctx, ndim, {axis})[0]; +} + +/*! + * \brief Convert an array of integers to int64 dtype. + * \param int_imms The input IntImms to be converted. + * \return The conversion result, where every IntImm has dtype int64 + */ +inline Array ConvertIntImmToInt64(const Array& int_imms) { + return int_imms.Map([](const IntImm& i) { return Downcast(cast(DataType::Int(64), i)); }); +} + +/************ Utilities for NN operators ************/ + +/*! + * \brief Complete the padding to a 4-length array. + * - If the padding length is 1, the same padding is used on all top/left/bottom/right sides + * - If the padding length is 2, top/bottom sides use padding[0] and left/right use padding[1] + * - If the padding length is 4, padding is in the order of (top, left, bottom, right) + * \param padding The given padding to be completed + * \return The completed padding. + * \throws Throws error if the input padding length is neither 1, 2 or 4. + */ +inline Array GetCompletePadding2D(Array padding) { + if (padding.size() == 1) { + return {padding[0], padding[0], padding[0], padding[0]}; + } else if (padding.size() == 2) { + return {padding[0], padding[1], padding[0], padding[1]}; + } else if (padding.size() == 4) { + return padding; + } + LOG(FATAL) << "The input padding length is expected to be either 1, 2 or 4. However, the given " + "padding is " + << padding; + throw; +} + +/*! + * \brief Check if the given tensor layout can be converted to the given target layout. + * If convertible, return the tensor layout and the bijective conversion in tir::Layout and + * tir::BijectiveLayout accordingly. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param tensor_layout The tensor layout to be checked + * \param tgt_layout The target layout to be matched + * \param tensor_name The name of the input tensor + * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout + * accordingly. + */ +inline std::pair CheckTensorLayout(const Call& call, + const BlockBuilder& ctx, + const String& tensor_layout, + const String& tgt_layout, + const String& tensor_name) { + tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); + tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); + if (!tensor2tgt.defined()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << call->op << " requires the given " << tensor_name + << " layout to be convertible from " << tgt_layout + << " layout. However, the given layout " << tensor_layout + << " is not convertible."); + } + return {_tensor_layout, tensor2tgt}; +} + +/*! + * \brief Check if the given tensor struct info has expected ndim per the given layout (or the ndim + * is unknown), and try to cast the shape to ShapeExpr. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param sinfo The input tensor struct info to be checked. + * \param layout The layout that the given tensor is expected to have. + * \return The shape of the input tensor in ShapeExpr, or `NullOpt` if the shape is unknown. 
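+ *
+ * Illustrative usage sketch (not part of the public API; `data_sinfo` is a hypothetical
+ * TensorStructInfo of an input that is expected to be laid out as "NCHW"):
+ *
+ * \code
+ *   Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
+ *       call, ctx, data_sinfo, tir::Layout("NCHW", DataType::Int(64)));
+ * \endcode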
+ */
+inline Optional<ShapeExpr> CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx,
+                                                         const TensorStructInfo& sinfo,
+                                                         const tir::Layout& layout) {
+  if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast<int>(layout.ndim())) {
+    ctx->ReportFatal(Diagnostic::Error(call->span)
+                     << "In " << call->op << ", layout " << layout << " requires the input to be "
+                     << layout.ndim() << "-dim tensor. However, the given input has ndim "
+                     << sinfo->ndim);
+  }
+  if (const auto* shape_expr = sinfo->shape.as<ShapeExprNode>()) {
+    return GetRef<ShapeExpr>(shape_expr);
+  }
+  return NullOpt;
+}
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_OP_COMMON_H_
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
new file mode 100644
index 000000000000..dd61091f7aaa
--- /dev/null
+++ b/src/relax/op/tensor/binary.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file binary.cc
+ * \brief binary broadcast operators.
+ */
+
+#include "binary.h"
+
+#include
+
+namespace tvm {
+namespace relax {
+
+template <typename FType>
+StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
+                                    FType f_compute_out_dtype) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo x1_sinfo = input_sinfo[0];
+  TensorStructInfo x2_sinfo = input_sinfo[1];
+
+  // DataType
+  DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo);
+
+  // ndims
+  int output_ndim;
+  if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
+    output_ndim = kUnknownNDim;
+  } else {
+    output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim);
+  }
+
+  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  // Shapes and ndims
+  if (x1_shape && x2_shape) {
+    // If all inputs have shapes, directly infer shapes
+    Optional<Array<PrimExpr>> output_shape =
+        InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
+    if (!output_shape.defined()) {
+      return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
+    } else {
+      ICHECK_EQ(static_cast<int>(output_shape.value().size()), output_ndim);
+      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
+    }
+  } else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) {
+    return TensorStructInfo(x1_sinfo->shape.value(), output_dtype);
+  } else {
+    return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
+  }
+}
+
+StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) {
+  return InferStructInfoBroadcast(call, ctx, InferBinaryArithOpOutDtype);
+}
+
+StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) {
+  return InferStructInfoBroadcast(
+      call, ctx,
+      [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& x1_sinfo,
const TensorStructInfo& x2_sinfo) { return DataType::Bool(); }); +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h new file mode 100644 index 000000000000..1402f5ac6bfa --- /dev/null +++ b/src/relax/op/tensor/binary.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.h + * \brief The functions to make Relax binary arithmetic and comparison operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_BINARY_H_ +#define TVM_RELAX_OP_TENSOR_BINARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ + Expr OpName(Expr x1, Expr x2) { \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {x1, x2}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(2) \ + .add_argument("x1", "Tensor", "The first input tensor.") \ + .add_argument("x2", "Tensor", "The second input tensor.") + +#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastArith) + +#define RELAX_REGISTER_CMP_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastCMP) + +/***************** Arithmetic operators *****************/ + +/*! \brief Addition with numpy-style broadcasting. */ +Expr add(Expr x1, Expr x2); + +/*! \brief Multiplication with numpy-style broadcasting. */ +Expr multiply(Expr x1, Expr x2); + + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_BINARY_H_ diff --git a/src/relax/utils.cc b/src/relax/utils.cc new file mode 100644 index 000000000000..5846f8116df2 --- /dev/null +++ b/src/relax/utils.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { + const DynTensorTypeNode* tt = ty.as(); + if (!tt) { + return false; + } + bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); + bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + return correct_dtype && correct_rank; +} + +bool IsLeafOrTuple(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} + +} // namespace relax +} // namespace tvm diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 92186a4ffea4..e30047174392 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -580,5 +580,85 @@ TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* *ret = CreatePrimFunc(arg_list, index_dtype_override); }); +// Relax version impl +PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, + const Array& root_stmts, CreateFuncInfo* info, + const Optional> tir_var_list) { + Array parameters; + Map buffer_map; + for (const te::Tensor& tensor : arg_list) { + Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); + parameters.push_back(arg); + auto it = info->tensor2buffers.find(tensor); + ICHECK(it != info->tensor2buffers.end()); + buffer_map.Set(arg, it->second); + } + + // add additional arguments for tir vars that are left unbound by match buffer + if (tir_var_list) { + for (const Var& v : tir_var_list.value()) { + parameters.push_back(v); + } + } + + PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), + /*body=*/SeqStmt::Flatten(root_stmts), + /*ret_type=*/VoidType(), + /*buffer_map=*/std::move(buffer_map)), + {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); + + const auto* complete = runtime::Registry::Get("script.Complete"); + ICHECK(complete); + func = (*complete)(std::move(func), info->root_alloc); + return func; +} + +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override) { + // Infomations used in CreatePrimFunc and its sub-functions. + CreateFuncInfo info(arg_list); + // Root body stmts. + Array root_stmts; + // Analyzer + arith::Analyzer analyzer; + + // Step 1. Create ordered array of operations and validate they are supported. + Array order = CollectOrderedOps(arg_list); + + // Step 2. Initialize buffer binds map + InitializeBufferBinds(order, &info); + + // Step 3. Rewrite compute stages into blocks. 
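+  //         Each te::Operation is lowered into a TIR block appended to root_stmts; the
+  //         tensor-to-buffer map in `info` is shared across stages so consumer stages reuse
+  //         the buffers created for their producers.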
+ for (const te::Operation& op : order) { + RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + } + auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, tir_var_list); + func = tir::BindParams(func, constants); + if (index_dtype_override.has_value()) { + func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); + } + auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); + return result; +} + +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override) { + return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override); +} + +TVM_REGISTER_GLOBAL("te.CreateRelaxPrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { + Array arg_list = args[0]; + Optional> tir_var_list = args[1]; + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[2].type_code() != kTVMNullptr) { + index_dtype_override = args[2].operator DataType(); + } + *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override); +}); + } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 4246347a16f3..946f024849bf 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -42,6 +42,23 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, std::optional index_dtype_override = std::nullopt); +// Relax version +// TODO(relax-team) combine with the relay version +/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override); + +/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the + * constants array is N, the last N tensors in arg_list will be treated as constant tensors. + * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants + * will be embedded in the body as AllocateConstNode. + */ +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override = std::nullopt); + } // namespace tir } // namespace tvm diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py new file mode 100644 index 000000000000..36a22f9712ea --- /dev/null +++ b/tests/python/relax/test_blockbuilder.py @@ -0,0 +1,542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
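+#
+# Tests for relax.BlockBuilder: binding/dataflow block scoping, emit and
+# emit_output/emit_func_output, struct info deduction for binary ops,
+# match_cast bindings, normalization, call_te/emit_te lowering to
+# relax.call_tir, and error handling for malformed function scopes.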
+ +import pytest +import tvm +import tvm.testing + +from tvm import te, tir, topi +from tvm import relax as rx, relay +from tvm.ir.base import assert_structural_equal +from tvm.relax import ExternFunc +from tvm.tir.function import PrimFunc + + +@tvm.register_func("test.blockbuilder.nop") +def nop(): + pass + + +def test_block_builder(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + bb._begin_binding_block() + gv0 = bb.emit(rx.op.add(x, y)) + bb._begin_dataflow_block() + lv0 = bb.emit(rx.op.multiply(gv0, y)) + gv1 = bb.emit_output(rx.op.multiply(lv0, lv0)) + b0 = bb._end_block() + bb._begin_dataflow_block() + lv1 = bb.emit(rx.op.multiply(gv0, y)) + gv2 = bb.emit_output(rx.op.multiply(lv1, lv1)) + b1 = bb._end_block() + gv3 = bb.emit(rx.op.add(x, y)) + b2 = bb._end_block() + + assert isinstance(b0, rx.DataflowBlock) + assert isinstance(b1, rx.DataflowBlock) + assert not isinstance(b2, rx.DataflowBlock) + + +def test_function_single_block(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + lv1 = bb.emit(rx.op.multiply(lv0, y)) + assert lv1.name_hint == "lv1" + gv0 = bb.emit_output(lv1) + assert gv0.name_hint == "gv" + bb.emit_func_output(gv0) + + func = bb.get()["func"] + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv0 + assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 3 + + +def test_function_multi_blocks(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + assert gv0.name_hint == "gv" + gv1 = bb.emit(rx.op.add(gv0, gv0)) + assert gv1.name_hint == "gv1" + with bb.dataflow(): + lv1 = bb.emit(rx.op.add(gv1, gv1)) + assert lv1.name_hint == "lv1" + gv2 = bb.emit_output(gv1) + bb.emit_func_output(gv2) + + func = bb.get()["func"] + + assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv2 + assert len(func.body.blocks) == 3 + assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[1].bindings) == 1 + assert len(func.body.blocks[2].bindings) == 2 + + +def test_multi_functions(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func1", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + with bb.function("func2", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(y, x)) + # TODO(@yuchen): enable block builder to reset local var unique name map + assert lv0.name_hint == "lv1" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + mod = 
bb.get() + func1 = mod["func1"] + assert func1.params[0] == x + assert func1.params[1] == y + assert len(func1.body.blocks) == 1 + func2 = mod["func2"] + assert func2.params[0] == x + assert func2.params[1] == y + assert len(func2.body.blocks) == 1 + + +def test_binary_shape_type_deduction(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + z = rx.Var("z", rx.TensorStructInfo([5], "float16")) + w = rx.Var("w", rx.TensorStructInfo([k], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y, z, w]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + + lv1 = bb.emit(rx.op.multiply(x, z)) + assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16")) + + lv2 = bb.emit(rx.op.multiply(z, w)) + assert isinstance(lv2.struct_info, rx.TensorStructInfo) + assert lv2.struct_info.ndim == 1 + assert lv2.struct_info.dtype == "float16" + + lv3 = bb.emit(rx.op.multiply(y, w)) + assert isinstance(lv3.struct_info, rx.TensorStructInfo) + assert lv3.struct_info.ndim == 1 + assert lv3.struct_info.dtype == "float16" + + gv0 = bb.emit_output(lv3) + bb.emit_func_output(gv0) + + assert isinstance(gv0.checked_type, rx.DynTensorType) + assert gv0.checked_type.ndim == 1 + assert gv0.checked_type.dtype == "float16" + + +def test_emit_match_cast(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1)) + y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8])) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + # lv0: Tensor((m, n), "float32") = + # match_cast(x: Tensor(_, "float32"], [m, n)) + lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32")) + assert isinstance(lv0, rx.DataflowVar) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) + + # lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n])) + lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n])) + assert lv1.struct_info == rx.ShapeStructInfo([m, n]) + gv0 = bb.emit_output(lv1) + + bb.emit_func_output(gv0) + func = bb.get()["func"] + block = func.body.blocks[0] + b0, b1 = block.bindings[:2] + assert isinstance(b0, rx.MatchCast) + assert isinstance(b1, rx.MatchCast) + + assert b0.value == x + assert b0.struct_info == rx.TensorStructInfo([m, n], "float32") + assert b0.var == lv0 + + assert b1.value == y + assert b1.struct_info == rx.ShapeStructInfo([m, n]) + assert b1.var == lv1 + + +def test_emit_match_cast_binding_in_dataflow_block(): + bb = rx.BlockBuilder() + + x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1)) + m = tir.Var("m", dtype="int64") + gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1)) + match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32")) + + with bb.function("main", [x]): + with bb.dataflow(): + bb.emit_normalized(match_cast) + bb.emit_output(gv) + bb.emit_func_output(x) + + func = bb.get()["main"] + block = func.body.blocks[0] + b0 = block.bindings[0] + assert isinstance(b0, rx.MatchCast) + + assert b0.value == x + assert isinstance(b0.struct_info, rx.TensorStructInfo) + assert b0.struct_info.shape[0] == m + assert b0.var == gv + + +def test_normalize(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = 
rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + # Call node + add_call = rx.op.multiply(x, y) + + bb.normalize(add_call) + shape = rx.get_shape_of(add_call) + + assert isinstance(shape, rx.ShapeExpr) + assert shape[0] == m + assert shape[1] == n + + # Tuple node + tuple_1 = rx.Tuple([x, y]) + bb.normalize(tuple_1) + assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) + + # Nested Tuple + tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) + bb.normalize(tuple_2) + type_anno0 = x.checked_type + type_anno1 = y.checked_type + assert_structural_equal( + tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) + ) + assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) + + +def test_call_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + with bb.dataflow(): + out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello")) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + +def test_call_te_with_unsupported_shape_arg(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo((200,), "float32")) + s = rx.Var("s", rx.ShapeStructInfo((200,))) + + with pytest.raises(AssertionError): + with bb.function("rx_func", [x]): + out = bb.emit(bb.call_te(topi.reshape, x, s)) + bb.emit_func_output(out) + + +def test_emit_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello") + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + def get_tir_func(): + A = te.placeholder((n, m), dtype="float32", name="A") + B = te.placeholder((n, m), dtype="float32", name="B") + C = te.placeholder((n, m), dtype="float32", name="C") + out = te_func((A, B), {"C": C}, "") + return tvm.te.create_prim_func([A, B, C, out], index_dtype_override="int64") + + # check TIR structure matches expected + assert_structural_equal(mod["te_func"].body, get_tir_func().body) + + # check Relax function calls 
TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "te_func" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.args[1][2] == z + + +def test_emit_te_multiple(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([128, m], "float32")) + + def te_func(A): + B = te.compute((128, 128), lambda i, j: A[i, j] + 1) + return B + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x) + y1 = bb.emit_te(te_func, y) + z1 = bb.emit_te(te_func, z) + bb.emit_func_output(z1) + + mod = bb.get() + rx_func = mod["rx_func"] + + prim_func = [] + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + prim_func.append(mod[gv]) + + # only two PrimFuncs were generated since two of them are equal so got deduped + assert len(prim_func) == 2 + assert rx_func.body.blocks[0].bindings[0].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[1].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[2].value.args[0].name_hint == "te_func1" + + +def test_emit_te_multiple_output(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + + def te_func(A): + B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B") + return (B0, B1) + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + z = rx.TupleGetItem(y, 0) + bb.emit_func_output([y, z]) + + rx_func = bb.get()["rx_func"] + + # check call tir output shape is a Tuple of ShapeExpr + assert rx_func.params[0] == x + call_node = rx_func.body.blocks[0].bindings[0].value + assert call_node.op == relay.op.get("relax.call_tir") + assert call_node.args[0].name_hint == "te_func" + assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo) + assert len(call_node.sinfo_args[0].fields) == 2 + assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr) + assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr) + + +def test_emit_te_extern(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_cblas_matmul"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "matmul" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert 
call_node.sinfo_args[0].shape[0] == n + assert call_node.sinfo_args[0].shape[1] == n + + +def test_nested_function_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, x)) + with bb.function("func1", [x, y]): + gv1 = bb.emit(rx.op.add(x, x)) + bb.emit_func_output(gv0) + + +def test_emit_func_output_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + bb.emit_func_output(gv0) + + +def test_func_params_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0, [x]) + + +def test_no_func_params_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func"): + gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), [])) + bb.emit_func_output(gv0) + + +def test_block_builder_scope_recovery(): + bb = rx.BlockBuilder() + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with pytest.raises(RuntimeError): + # this line fails + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + + # current should be recovered + assert rx.BlockBuilder.current() is None + + # second attempt to do it correctly. + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 4eeaed1e0b50..902c4785610f 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np -import pytest import tvm from tvm import relax as rx from tvm import tir from tvm.script import relax as R +import pytest def _check_equal(x, y, map_free_vars=False): @@ -255,4 +255,4 @@ def test_datatype_imm(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py new file mode 100644 index 000000000000..8165107394c9 --- /dev/null +++ b/tests/python/relax/test_expr_functor.py @@ -0,0 +1,746 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relax, tir +from tvm.ir import Op +from tvm.ir.base import assert_structural_equal +from tvm.relax import PyExprMutator, PyExprVisitor +from tvm.relax.expr import ( + BindingBlock, + Call, + Constant, + DataflowBlock, + DataflowVar, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + PrimValue, + StringImm, + DataTypeImm, + Var, + VarBinding, +) +from tvm.script import relax as R +import pytest + +m, n = tir.Var("m", "int64"), tir.Var("n", "int64") +x = relax.Var("x", R.Tensor([n], "float32")) +y = relax.Var("y", R.Tensor([m, n], "float32")) +bb = relax.BlockBuilder() + + +@relax.expr_functor.visitor +class BasicVisitor(PyExprVisitor): + """Default ExprVisitor""" + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@relax.expr_functor.visitor +class ASTPrinter(PyExprVisitor): + """Print relax AST in structured format. The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> None: + self.log.add("Constant") + + def visit_global_var_(self, op: GlobalVar) -> None: + self.log.add("GlobalVar") + + def visit_tuple_(self, op: Tuple) -> None: + self.log.add("Tuple") + self.log.push_scope() + for field in op.fields: + self.visit_expr(field) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + self.log.add("DataflowVar") + + def visit_function_(self, op: Function) -> None: + self.log.add("Function") + self.log.push_scope() + for param in op.params: + self.visit_var_def(param) + + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_call_(self, op: Call) -> None: + self.log.add("Call") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_if_(self, op: If) -> None: + self.log.add("If") + self.log.push_scope() + self.visit_expr(op.cond) + self.visit_expr(op.true_branch) + self.visit_expr(op.false_branch) + self.log.pop_scope() + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + self.log.add("TupleGetItem") + self.log.push_scope() + self.visit_expr(op.tuple_value) + self.log.pop_scope() + + def visit_prim_value_(self, op: PrimValue) -> None: + self.log.add("PrimValue") + + def visit_string_imm_(self, op: StringImm) -> None: + self.log.add("StringImm") + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + self.log.add("DataTypeImm") + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + self.log.add("ShapeExpr") + + def visit_extern_func_(self, op: ExternFunc) -> None: + self.log.add("ExternFunc") + + def visit_seq_expr_(self, op: 
SeqExpr) -> None: + self.log.add("SeqExpr") + self.log.push_scope() + for block in op.blocks: + self.visit_binding_block(block) + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_var_binding_(self, binding: VarBinding) -> None: + self.log.add("VarBinding") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_match_cast_(self, binding: MatchCast) -> None: + self.log.add("MatchCast") + self.log.push_scope() + self.visit_var_def(binding.var) + self.visit_expr(binding.value) + self.log.pop_scope() + + def visit_binding_block_(self, block: BindingBlock) -> None: + self.log.add("BindingBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + self.log.add("DataflowBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_var_def_(self, var: Var) -> None: + self.log.add("VarDef") + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + self.log.add("DataflowVarDef") + + +@relax.expr_functor.mutator +class BasicMutator(PyExprMutator): + """Default ExprMutator""" + + +@relax.expr_functor.mutator +class ASTPostPrinterMutator(PyExprMutator): + """Print relax AST in the post order format.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Constant") + return op + + def visit_global_var_(self, op: GlobalVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("GlobalVar") + return op + + def visit_tuple_(self, op: Tuple) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Tuple") + return op + + def visit_var_(self, op: Var) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Var") + return op + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataflowVar") + return op + + def visit_function_(self, op: Function) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Function") + return op + + def visit_call_(self, op: Call) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Call") + return op + + def visit_if_(self, op: If) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("If") + return op + + def visit_op_(self, op: Op) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Op") + return op + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("TupleGetItem") + return op + + def visit_prim_value_(self, op: PrimValue) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("PrimValue") + return op + + def visit_string_imm_(self, op: StringImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("StringImm") + return op + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataTypeImm") + return op + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ShapeExpr") + return op + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ExternFunc") + return op + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("SeqExpr") + return op + + def 
visit_var_binding_(self, binding: VarBinding) -> None: + """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_var = self.visit_var_def(binding.var) + + self.log.add("VarBinding") + if binding.var.same_as(new_var) and binding.value.same_as(new_value): + self.builder_.emit_normalized(binding) + return + + temp = self.with_struct_info(new_var, new_value.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.builder_.emit_normalized(VarBinding(new_var, new_value)) + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchCastNode* binding) on the C++ side.""" + new_var = self.visit_var_def(binding.var) + new_value = self.visit_expr(binding.value) + + temp = self.with_struct_info(new_var, binding.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.log.add("MatchCast") + self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info)) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" + self.builder_._begin_binding_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("BindingBlock") + return self.builder_._end_block() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side.""" + self.builder_._begin_dataflow_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("DataflowBlock") + return self.builder_._end_block() + + def visit_var_def_(self, var: Var) -> None: + """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side.""" + self.log.add("VarDef") + return var + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" + self.log.add("DataflowVarDef") + return var + + +def basic_check(expr, visitor_str, mutator_str): + def visit(f, expr): + if isinstance(expr, relax.Expr): + return f.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + return f.visit_binding_block(expr) + + # check no overloading case + basic_visitor = BasicVisitor() + visit(basic_visitor, expr) + + # check the output log + log_visitor = ASTPrinter() + visit(log_visitor, expr) + assert str(log_visitor.log) == visitor_str + + # check no overloading case + basic_mutator = BasicMutator() + # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(basic_mutator, expr), expr) + + # check the output log and return value + post_log_mutator = ASTPostPrinterMutator() + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(post_log_mutator, expr), expr) + assert str(post_log_mutator.log) == mutator_str + + +def test_constant(): + basic_check(relax.const(1.0), "Constant", "Constant") + + +def test_var(): + basic_check(x, "Var", "Var") + + +def test_dataflow_var(): + lv = relax.DataflowVar("lv", R.Tensor([n], "float32")) + basic_check(lv, "DataflowVar", "DataflowVar") + + +def 
test_tuple(): + t = relax.Tuple([x, y]) + basic_check(t, "\n".join(["Tuple", "\tVar", "\tVar"]), "\n".join(["Var", "Var", "Tuple"])) + + +def test_global_var(): + gv = relax.GlobalVar("gv") + basic_check(gv, "GlobalVar", "GlobalVar") + + +def test_seq_expr(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + basic_check( + seq_expr, + "\n".join( + [ + "SeqExpr", + "\tBindingBlock", + "\t\tVarBinding", + "\t\t\tConstant", + "\t\t\tVarDef", + "\tVar", + ] + ), + "\n".join(["Constant", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"]), + ) + + +def test_shape_expr(): + x = relax.ShapeExpr([m, n]) + basic_check(x, "ShapeExpr", "ShapeExpr") + + +def test_call(): + call_node = relax.op.add(x, y) + basic_check( + call_node, + "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), + "\n".join(["Op", "Var", "Var", "ShapeExpr", "Call"]), + ) + + +def test_if(): + if_node = relax.If(x, x, x) + basic_check( + if_node, + "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]), + ) + + +def test_tuple_getitem(): + tuple_getitem_node = relax.TupleGetItem(relax.Tuple([x, y]), 0) + basic_check( + tuple_getitem_node, + "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"]), + "\n".join(["Var", "Var", "Tuple", "TupleGetItem"]), + ) + + +def test_binding_block(): + bb._begin_binding_block() + gv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "BindingBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tVarDef", + "\tMatchCast", + "\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "VarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "VarDef", + "MatchCast", + "BindingBlock", + ] + ), + ) + + +def test_dataflow_block(): + bb._begin_dataflow_block() + lv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "DataflowBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tDataflowVarDef", + "\tMatchCast", + "\t\tDataflowVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "DataflowVarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "DataflowVarDef", + "MatchCast", + "DataflowBlock", + ] + ), + ) + + +def test_function(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + func = relax.Function([x], seq_expr, R.Tensor([n], "float32")) + basic_check( + func, + "\n".join( + [ + "Function", + "\tVarDef", + "\tSeqExpr", + "\t\tBindingBlock", + "\t\t\tVarBinding", + "\t\t\t\tConstant", + "\t\t\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "VarDef", + "Constant", + "VarDef", + "VarBinding", + "BindingBlock", + "Var", + "SeqExpr", + "Function", + ] + ), + ) + + +def test_extern_func(): + func = relax.ExternFunc("f") + basic_check(func, "ExternFunc", "ExternFunc") + + +def test_inherit(): + # The internal class is not instantiated. 
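+    # Inheriting from an undecorated PyExprVisitor subclass is allowed; only the
+    # decorated LeafVisitor defined below is constructed and used for visiting.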
+ class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_inherit_with_cls(): + # The decorator converts `InternalVisitor` to a wrapper class. + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined. + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_wrong_inherit(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def visit_call_(self, op: Call) -> None: + pass + + with pytest.raises( + TypeError, + match="Inheritance from a decorated object `LeafVisitor` is not allowed. 
Please inherit from `LeafVisitor._cls`.", + ): + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + pass + + +def test_call_visitor_super(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + super().visit_call_(op) # call PyExprVisitor.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + super().visit_call_(op) # call InternalVisit.visit_call_ + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +def test_call_mutator_super(): + @relax.expr_functor.mutator + class InternalMutator(PyExprMutator): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + return super().visit_call_(op) # call PyExprMutator.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + return super().visit_var_(op) # call PyExprMutator.visit_var_ + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + return super().visit_op_(op) # call PyExprMutator.visit_op_ + + @relax.expr_functor.mutator + class LeafMutator(InternalMutator._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + return super().visit_call_(op) # call InternalMutator.visit_call_ + + call_node = relax.op.add(x, y) + im = InternalMutator() + im.visit_expr(call_node) + assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lm = LeafMutator() + lm.visit_expr(call_node) + assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +if __name__ == "__main__": + tvm.testing.main() From d6b825a2f63f5be2aa61d7680aa07491d1f93e2e Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Mon, 6 Feb 2023 19:30:08 -0800 Subject: [PATCH 2/2] tests and lint --- include/tvm/relax/analysis.h | 2 +- include/tvm/relax/block_builder.h | 2 +- include/tvm/relax/expr_functor.h | 2 +- python/tvm/meta_schedule/utils.py | 2 +- python/tvm/te/operation.py | 2 +- src/relax/op/op.cc | 2 +- src/relax/op/tensor/binary.h | 1 - .../test_analysis_struct_info_analysis.py | 143 ++++++++++++++++++ 8 files changed, 149 insertions(+), 7 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 6125171598b3..ad2bd19aa41a 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -87,7 +87,7 @@ TVM_DLL StructInfo StructInfoFromType(const Type& type); /*! * \return Derive the call's ret value struct info from inputs. - * \param func_info The function struct info. + * \param finfo The function struct info. * \param call The call expression to be derived. * \param ctx The builder context. * \param ana Optional context analyzer to prove symbolic expression equality. 
diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index d92e5faf279b..7222ae08f956 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -180,7 +180,7 @@ class BlockBuilderNode : public Object { /*! * \brief Emit a binding that is already normalized. * - * \param binding A binding whose value is already normalized. + * \param normalized_binding A binding whose value is already normalized. * * \note This function requires binding to be pre-normalized. */ diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ac3ff8d79376..655ecc52b656 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -514,7 +514,7 @@ class ExprMutator : public ExprMutatorBase { /*! * \brief Post-order rewrite a node and normalize. - * \param T The node type to be rewritten. + * \tparam T The node type to be rewritten. * \param op The node to be rewritten. * \return The node after post rewritten. */ diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 9132402b4c9a..1f2cfd34016e 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -89,7 +89,7 @@ def method(*args, **kwargs): return None assert isinstance(cls.__base__, type) - if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore raise TypeError( ( f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 1779f6efc595..0351f1f623ae 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -19,7 +19,7 @@ # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List, Optional, Union +from typing import List, Optional import tvm._ffi import tvm.arith._ffi_api diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index a3fa4412180c..8640ed79adb0 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,9 +18,9 @@ */ #include #include +#include #include #include -#include namespace tvm { namespace relax { diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 1402f5ac6bfa..a7aea576b685 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -64,7 +64,6 @@ Expr add(Expr x1, Expr x2); /*! \brief Multiplication with numpy-style broadcasting. 
*/ Expr multiply(Expr x1, Expr x2); - } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index faf8fedcf4bf..03b98f8a565e 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -319,6 +319,149 @@ def fn_info_erased(): assert fopaque.is_base_of(fn_info_shape(1)) +def _check_derive(ctx, finfo, args_sinfo, ret): + gv = rx.GlobalVar("test") + rx.expr._update_struct_info(gv, finfo) + args = [] + for i, sinfo in enumerate(args_sinfo): + arg = rx.Var("arg%i" % i, sinfo) + args.append(arg) + call = rx.Call(gv, args) + derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx) + tvm.ir.assert_structural_equal(ret, derived_ret) + + +def test_derive_call_ret_struct_info(): + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("float32") + + n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64") + bb = rx.BlockBuilder() + # derivation cases + with bb.testing_scope(def_vars=[n, m]): + + def func0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([n, m], "float32") + z = rx.TensorStructInfo([m + c, n], "float32") + return rx.FuncStructInfo([x], z) + + # Tensor => Tensor + _check_derive( + bb, + func0(1), + [rx.TensorStructInfo([10, 11], "float32")], + rx.TensorStructInfo([12, 10], "float32"), + ) + + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo([n, m], "float32")], + rx.TensorStructInfo([m + 2, n], "float32"), + ) + + # passing in information that cannot deduce n, m + # it is still OK as type still matches, return an + # eriased output + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32")], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error: wrong number of arguments + with pytest.raises(TVMError): + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error:type mismatch + with pytest.raises(TVMError): + _check_derive(bb, func0(2), [obj0], obj0) + + # opaque derivation + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + _check_derive(bb, fopaque0(), [obj0, prim0], obj0) + _check_derive(bb, fopaque1(), [obj0, prim0], prim0) + + # recursive tuple derivation + def func_tuple0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, c], "float32") + x1 = rx.TensorStructInfo([n + c, m], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + _check_derive( + bb, + func_tuple0(2), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 2], "float32"), + rx.TensorStructInfo([n + 2, 10], "float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + def func_tuple1(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, m], "float32") + x1 = rx.TensorStructInfo([n + c, c], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + # Still OK, to pass erased tensor into n+2, n is captured by other argument. 
+ _check_derive( + bb, + func_tuple1(4), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 4], "float32"), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]), + ) + + # tuple length mismatch is not causes an error + with pytest.raises(TVMError): + _check_derive( + bb, + func_tuple0(4), + [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + # mixed shape types + def func_shape_mixed(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.ShapeStructInfo([n, m]) + f0 = func_tuple0(c) + z = rx.ShapeStructInfo([m + n, c]) + return rx.FuncStructInfo([x0, f0], z) + + _check_derive( + bb, + func_shape_mixed(3), + [ + rx.ShapeStructInfo([10, 20]), + rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2)), + ], + rx.ShapeStructInfo([30, 3]), + ) + + def _check_lca(lhs, rhs, target): tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target)
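For reference, a minimal usage sketch of the Python BlockBuilder API introduced by this
patch, distilled from test_function_single_block above (illustrative only; variable names
follow that test):

    from tvm import relax as rx, tir

    m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([n], "float16"))

    bb = rx.BlockBuilder()
    with bb.function("func", [x, y]):                   # open a function scope
        with bb.dataflow():                             # open a dataflow block
            lv = bb.emit(rx.op.add(x, y))               # local binding; struct info is deduced
            gv = bb.emit_output(rx.op.multiply(lv, y))  # output var visible outside the block
        bb.emit_func_output(gv)                         # close the function scope
    mod = bb.get()                                      # IRModule containing "func"
    func = mod["func"]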