diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4a00de802c61..b54a067e1c94 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; // PrimExprs that are useful as runtime containers. diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h index 167a593ff377..e0b347de1783 100644 --- a/include/tvm/relay/attrs/call.h +++ b/include/tvm/relay/attrs/call.h @@ -35,7 +35,7 @@ namespace relay { * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. */ struct CallLoweredAttrs : public tvm::AttrsNode { - /*! \brief The metadata attached to the call node. */ + /*! \brief Additional metadata attached to the call node. Should be replaced by explict fields. */ Map metadata; TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 5869f878aa85..052d04fe2411 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -170,19 +170,40 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func); * \brief namespace of the attributes that can be attached to a relay::Function. */ namespace attr { -/*! \brief Mark the function as a primitive function. */ + +/*! + * \brief Mark the function as representing a sub-graph which is to be lowered or compiled as + * a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc. + * If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below. + * The function body should be considered opaque by Relay, and many passes simply ignore these + * functions. + * + * Type: Integer + */ constexpr const char* kPrimitive = "Primitive"; + +/*! + * \brief Mark the function as externally implemented, ie bound in a runtime::Module within the + * IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally + * the only attribute when present. + * + * Type: Integer + */ +constexpr const char* kExtern = "Extern"; + /*! - * \brief Indicate the compiler that should be used for building this function. - * When this is unset or set to "default", the default compilation pipeline will be used. + * \brief Indicates the name of the external codegen 'compiler' that should be used to lower + * or compile the function other than TVM's default lowering pipeline. The name may correspond + * to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'. + * + * Type: String */ constexpr const char* kCompiler = "Compiler"; + /*! \brief Indicate if the function is a closure. */ constexpr const char* kClosure = "Closure"; /*! \brief Store a Var to parameter/Constant mapping on a Function. */ constexpr const char* kParams = "__params__"; -/*! \brief Store the unique external symbol for external compilers. */ -constexpr const char* kExternalSymbol = "ExternalSymbol"; /*! \brief Mark if the function should be avoided being optimized. */ constexpr const char* kSkipOptimization = "SkipOptimization"; /*! \brief Treat the function as a composite operator. */ @@ -193,6 +214,7 @@ constexpr const char* kInline = "Inline"; constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; /*! \brief Mark the function as only composed of reshape operations. */ constexpr const char* kReshapeOnly = "relay.reshape_only"; + } // namespace attr } // namespace relay diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 9f253f8e88ba..694dbb45218c 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -802,24 +802,6 @@ def Inline(): return _ffi_api.Inline() -def InlineComposites(target): - """Perform inlining on the given Relay IR module. The functions originate - from the MergeComposite pass based on an input pattern table will fold back - to main. Currently, this is used for the TRT BYOC which expects a single - primitive function to operate on. - - Parameters - ---------- - target: str - The byoc target for which ops need to fold back to primitive function. - Returns - ------- - ret: tvm.transform.Pass - The registered pass that performs inlining for a Relay IR module. - """ - return _ffi_api.InlineComposites(target) - - def gradient(expr, mod=None, mode="higher_order"): """ Transform the input function, @@ -1386,3 +1368,55 @@ def SplitArgs(max_function_args): The registered pass for constant folding. """ return _ffi_api.SplitArgs(max_function_args) + + +def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): + """Outlines all literal functions in direct call positions which have a "Compiler" + attribute. + + The outlined functions are bound to unique global vars according to their existing + "global_symbol" attribute. At most one function with the same global symbol is outlined. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. + + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + to prepare the IRModule before custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter) + + +def MarkCompilerFunctionsAsExtern(compiler_filter=""): + """Marks all global functions which have a "Compiler" attribute matching + compiler_filter as 'extern'. + + The function's attributes are replaced with a single "Extern" attribute, and + all calls to the function are switched to use the 'call_lowered' calling convention. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. + + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + cleanup the IRModule after custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f04..a3318bf94fc6 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint, Type type) { +GlobalVar::GlobalVar(String name_hint, Type type, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->checked_type_ = std::move(type); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 4ac1ceef26dc..505784e4bf70 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -295,8 +295,6 @@ struct Tokenizer { int line = this->line; int column = this->col; - ICHECK_EQ(Peek(), '['); - Next(); std::stringstream type_key; while (More() && Peek() != ']') { type_key << Next(); @@ -498,7 +496,7 @@ struct Tokenizer { auto token = NewToken(TokenType::kQuestion); Next(); return token; - } else if (MatchString("meta")) { + } else if (MatchString("meta[")) { return TokenizeMetaRef(); } else if (next == '#') { return TokenizeAttr(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 73b44f7361a5..c78f3abd6ecc 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -168,7 +168,7 @@ class TECompilerImpl : public TECompilerNode { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - // Unfortuantely, Optional() is indistinguishable from + // Unfortunately, Optional() is indistinguishable from // NullValue(), and DictAttrs() is nullptr, so to erase the attributes, we // need pass in DictAttrs()), which is a DictAttrs containing no // attributes. @@ -176,8 +176,8 @@ class TECompilerImpl : public TECompilerNode { WithFields(GetRef(function_node), function_node->params, function_node->body, function_node->ret_type, function_node->type_params, /* erase attributes */ DictAttrs(Map())); - // Mark function as 'extern' using the "ExternalSymbol" attribute. - function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); + // Mark function as 'extern'. + function = WithAttr(std::move(function), attr::kExtern, Integer(1)); module->Add(kv2.first, function); } } @@ -688,7 +688,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Nothing to lower inside primitive/external functions. return GetRef(function_node); } else { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e0b742a84090..d9730b1b5a4c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -922,7 +922,7 @@ void VMCompiler::LowerImpl(IRModule mod) { for (const auto& pair : context_.module->functions) { auto gvar = pair.first; if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kExternalSymbol).defined()) { + if (n->HasNonzeroAttr(attr::kExtern)) { // Already compiled during lowering. continue; } @@ -1131,7 +1131,7 @@ size_t VMCompiler::PopulateGlobalMap() { // Excludes PrimFuncs and externs, which are managed by the primitive_map_. for (const auto& kv : context_.module->functions) { if (const auto* function_node = kv.second.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { context_.global_map.emplace(kv.first, context_.global_map.size()); } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index bf0dd577a4d2..63e74144e061 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const { const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { if (const auto* function_node = base_func.as()) { if (!function_node->GetAttr(attr::kCompiler).defined() && - !function_node->GetAttr(attr::kExternalSymbol).defined() && + !function_node->HasNonzeroAttr(attr::kExtern) && !function_node->HasNonzeroAttr(attr::kSkipOptimization)) { return function_node; } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 234cafdca150..41b47401de1c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("tensor_a", "3D Tensor", "The first input.") .add_argument("tensor_b", "3D Tensor", "The second input.") diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc new file mode 100644 index 000000000000..b98d089b346a --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.cc + * \brief Helper passes for working with functions with the "Compiler" attribute. + */ + +#include "./compiler_function_utils.h" + +#include "../op/call/call.h" +#include "tvm/relay/analysis.h" +#include "tvm/relay/expr_functor.h" + +namespace tvm { +namespace relay { +namespace transforms { +namespace { + +/*! + * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given + * module will be extended with the newly outlined functions. + */ +class Outliner : public MixedModeMutator { + public: + Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) + : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* function_node = new_call->op.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + auto function = GetRef(function_node); + DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + << "' attribute should not have free variables"; + // Ask the cache to supply a unique global var for this function. + GlobalVar global_symbol = cache_->GetGlobalSymbol(function); + // Depending on the cache's implementation, two structurally equal (but not object equal) + // functions may be assigned the same global symbol. If so we'll lift it just once, but + // rewrite all the calls. + if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { + function = + WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); + mod_->Add(global_symbol, function); + } + // Update the call. + return WithFields(new_call, global_symbol); + } + } + return post; + } + + private: + /*! + * \brief A cached mapping from functions to global variables. Depending on the implementation + * the cache may generate fresh symbols or require the function to already have a "global_symbol" + * attribute, and may share symbols between structurally equal functions. + */ + GlobalSymbolCache* cache_; + /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ + std::string compiler_filter_; + /*! \brief Module being rewritten. */ + IRModule mod_; +}; + +/*! + * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention. + */ +class CallRewriter : public MixedModeMutator { + public: + CallRewriter(std::string compiler_filter, IRModule mod) + : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* global_var_node = new_call->op.as()) { + if (const auto* function_node = + mod_->Lookup(GetRef(global_var_node)).as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + Optional opt_global_symbol = + function_node->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()); + GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value()); + CallLoweredAttrs attrs; + attrs.metadata.Set("relay_attrs", new_call->attrs); + return CallLowered(global_symbol, new_call->args, attrs, new_call->span); + } + } + } + return post; + } + + private: + /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ + std::string compiler_filter_; + /*! \brief Module being rewritten. */ + IRModule mod_; +}; + +} // namespace + +GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "ExistingGlobalSymbolCache requires all functions to already have a '" + << tvm::attr::kGlobalSymbol << "' attribute"; + std::string global_symbol = opt_global_symbol.value(); + auto itr = global_vars_.find(global_symbol); + if (itr != global_vars_.end()) { + return itr->second; + } + // Ok if function does not have a checked_type, but if it does capture it in the global var. + GlobalVar global_var(global_symbol, function->checked_type_, function->span); + global_vars_.emplace(global_symbol, global_var); + return global_var; +} + +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( + IRModule mod, transform::PassContext ctx) { + IRModule output_mod = GetRef(mod.CopyOnWrite()); + for (const auto& kv : mod->functions) { + const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second); + if (function_node) { + Expr new_body = + Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Add(kv.first, new_function); + } + } + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "OutlineCompilerFunctions", {}); +} + +// Any Java programmers in the house? +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) { + return OutlineCompilerFunctions(std::make_shared(), + std::move(compiler_filter)); +} + +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { + IRModule output_mod = mod->ShallowCopy(); + + // First pass, rewrite the calls. + // We have to do this before marking functions as 'extern' to know which calls to rewrite! + for (const auto& kv : mod->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + Expr new_body = + CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Update(kv.first, new_function); + } + } + + // Second pass, mark functions as 'extern'. + for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { + auto new_function = WithFields( + GetRef(function_node), function_node->params, function_node->body, + function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); + new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); + output_mod->Update(kv.first, new_function); + } + } + } + + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") + .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); +TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") + .set_body_typed(MarkCompilerFunctionsAsExtern); + +} // namespace transforms +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h new file mode 100644 index 000000000000..7b5143444bf8 --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.h + * \brief Helper passes for working with functions with the "Compiler" attribute. + * + * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external + * codegen may find the following two helper passes useful: + * + * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a + * matching "Compiler" attribute to be global functions, using the "global_symbol" attribute + * already assigned. Can be used before custom lowering. + * + * Note that ideally "Compiler" attributed functions would be made global functions as early as + * possible and would stay that way. However, the GraphExecutorCodegen and AOTExecutorCodegen + * assume the entire model can be represented by a single 'main' function, and the Inline pass + * is run to respect that assumption. So this pass is mostly just to undo that Pass after modules + * have passed through the 'codegen' keyhole. + * + * See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc. + * + * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom + * cache to both allocate "global_symbol" names and ensure two strucurally equal functions are + * assigned the same name, and thus lowered only once. This is used by Collage when preparing + * the optimally partitioned IRModule). + * + * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" + * attribute with the same function with just an "Extern" attribute, signalling the function + * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered' + * calling convention. Can be used after lowering to cleanup the IRModule. + * + * Note that the above behaviour is hard coded within the TECompiler, but is only available to + * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point. + */ + +#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ +#define TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ + +#include +#include +#include + +#include "tvm/ir/transform.h" +#include "tvm/relay/function.h" + +namespace tvm { +namespace relay { +namespace transforms { + +/*! + * \brief Abstract class representing a cache of unique global vars keyed by functions. This can + * be used to ensure structurally equal functions are assigned the same global var object, and + * thus lowered at most once. + */ +class GlobalSymbolCache { + public: + virtual GlobalVar GetGlobalSymbol(const Function& function) = 0; +}; + +/*! + * \brief A \p GlobalSymbolCache that requires every "Compiler" attributed function to already + * have a "global_symbol" attribute. + */ +class ExistingGlobalSymbolCache : public GlobalSymbolCache { + public: + ExistingGlobalSymbolCache() = default; + + GlobalVar GetGlobalSymbol(const Function& function) final; + + private: + /*! \brief Maps already seen global symbol names to their corresponding GlobalVar objects. */ + std::unordered_map global_vars_; +}; + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The given \p GlobalSymbolCache is used to determine a unique global symbol for each + * function, which is also assigned to the "global_symbol" attribute of the new global function. + * + * At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + */ +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter = ""); + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The functions are bound to unique global vars according to their existing + * "global_symbol" attribute. At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + * to prepare the IRModule before custom lowering. + */ +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = ""); + +/*! + * \brief A pass to mark all global functions which have a "Compiler" attribute matching + * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and + * rewrite all calls to such functions to use the 'call_lowered' calling convention. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + * cleanup the IRModule after custom lowering. + */ +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); + +} // namespace transforms +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 45cb8271b074..18d2de1bdede 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -84,7 +84,7 @@ class PurityVisitor : ExprFunctor { for (const auto& kv : mod_->functions) { if (const auto* function_node = kv.second.as()) { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Ignore primitive and external functions. continue; } @@ -133,9 +133,11 @@ class PurityVisitor : ExprFunctor { Purity VisitExpr_(const GlobalVarNode* global_var_node) final { auto global_var = GetRef(global_var_node); + ICHECK(mod_->ContainGlobalVar(global_var_node->name_hint)) + << "No definition for '" << global_var_node->name_hint << "'"; auto func = mod_->Lookup(global_var); if (const auto* function_node = func.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { return VisitGlobalFunction(global_var, GetRef(function_node)); } } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index c55b6778093e..012b3579494f 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -110,7 +110,7 @@ class Inliner : ExprMutator { if (!function_node->body.defined()) return false; // The function must be annotated with the inline attribute. - // (Note that external functions do not have this attribute!) + // (Note that partitioned functions and external functions do not have this attribute!) if (!function_node->HasNonzeroAttr(attr::kInline)) return false; // The function is not able to be inlined if any callee under the CallGraph @@ -136,8 +136,7 @@ class Inliner : ExprMutator { auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined() && - !func->GetAttr(attr::kExternalSymbol).defined()) { + if (!func->GetAttr(attr::kCompiler).defined() && !func->HasNonzeroAttr(attr::kExtern)) { ICHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py new file mode 100644 index 000000000000..13e0f98e79f1 --- /dev/null +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License +"""Unit tests for the OutlineCompilerFunctionsWithExistingGlobalSymbols and + MarkCompilerFunctionsAsExtern external codegen helper passes.""" + +import tvm +import tvm.testing +import numpy as np + + +def make_const(dtype, shape): + return tvm.relay.const(np.random.rand(*shape).astype(dtype)) + + +def make_consts(dtype, shapes): + return [make_const(dtype, shape) for shape in shapes] + + +metatable = { + "relay.Constant": make_consts( + "float16", + [ + (2304, 768), # 0 + (2304,), # 1 + (600, 32, 64), # 2 + ], + ), + "attributes": [{"relay_attrs": None}], +} + + +def inlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %0 = fn(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + }; + %1 = %0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + +def expected_outlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def expected_extern_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, meta[relay.Constant][0], meta[relay.Constant][1]), metadata=meta[attributes][0]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Extern=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def test_outline_compiler_functions_with_existing_global_symbols(): + actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( + "cutlass" + )(inlined_mod()) + tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) + + +def test_mark_compiler_functions_as_extern(): + actual_extern_mod = tvm.relay.transform.MarkCompilerFunctionsAsExtern("cutlass")( + expected_outlined_mod() + ) + tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main()