From 123fcc9b6465d954a18996347da12791cb8c5768 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Wed, 25 May 2022 17:15:46 -0700 Subject: [PATCH 1/2] [BYOC] Two helper passes for external codegen using RelayToTIR custom pass machinery (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md)). For reasons explained in the above thread I'm moving CUTLASS to be IRModule-at-a-time external codegen using a custom RelayToTIR pass instead of the traditional function-at-a-time external codegen using a relay.ext.cutlass registered function. This means some of the rewriting done on-the-fly by LowerTEPass now needs to be done by the custom pass directly. This PR supplies two passes which ease that burden: - Before starting the CUTLASS-specific processing, make sure all "Compiler" attributed functions have unique global definitions (ie are outlined). Though functions start in this form after BYOC partitioning, under Graph and AOT compilation flows those functions are then inlined to pass through the 'codegen' keyhole which assumes the whole model is just one self-contained main function. This pass will undo that. (I gave up trying to just remove the inlining in the first place.) - After the CUTLASS-specific processing the now compiled "Compiler" attributed functions need to be marked as 'extern'. The te_compiler.cc uses the "ExternalSymbol" attribute for that, but since a) the symbol name is never needed, only the presence of the attribute is significant downstream and b) "ExternalSymbol" is easy to confuse with "global_symbol", I just replaced "ExternalSymbol" with "Extern" with an Integer(1) (cf "Primitive"). The outlining pass is a little more general than necessary because it will also be used by Collage to rewrite the IRModule into optimally partitioned form while making maximal reuse of partition functions. 
Hence the abstract GlobalSymbolCache. --- include/tvm/ir/expr.h | 3 +- include/tvm/relay/attrs/call.h | 2 +- include/tvm/relay/function.h | 24 ++- python/tvm/relay/transform/transform.py | 66 ++++-- src/ir/expr.cc | 3 +- src/parser/tokenizer.h | 4 +- src/relay/backend/te_compiler.cc | 6 +- src/relay/backend/vm/compiler.cc | 4 +- src/relay/ir/function.cc | 2 +- src/relay/op/nn/nn.cc | 1 + .../transforms/compiler_function_utils.cc | 201 ++++++++++++++++++ .../transforms/compiler_function_utils.h | 135 ++++++++++++ src/relay/transforms/dead_code.cc | 6 +- src/relay/transforms/inline.cc | 5 +- .../transform/test_compiler_function_utils.py | 162 ++++++++++++++ 15 files changed, 585 insertions(+), 39 deletions(-) create mode 100644 src/relay/transforms/compiler_function_utils.cc create mode 100644 src/relay/transforms/compiler_function_utils.h create mode 100644 tests/python/relay/transform/test_compiler_function_utils.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4a00de802c61..b54a067e1c94 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; // PrimExprs that are useful as runtime containers. diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h index 167a593ff377..e0b347de1783 100644 --- a/include/tvm/relay/attrs/call.h +++ b/include/tvm/relay/attrs/call.h @@ -35,7 +35,7 @@ namespace relay { * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. */ struct CallLoweredAttrs : public tvm::AttrsNode { - /*! \brief The metadata attached to the call node. */ + /*! 
\brief Additional metadata attached to the call node. Should be replaced by explict fields. */ Map metadata; TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 5869f878aa85..29ce07753398 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -170,19 +170,34 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func); * \brief namespace of the attributes that can be attached to a relay::Function. */ namespace attr { -/*! \brief Mark the function as a primitive function. */ + +/*! + * \brief Mark the function as a primitive function. Should be bound to \p Integer(1). + * + * Type: Integer + */ constexpr const char* kPrimitive = "Primitive"; + +/*! + * \brief Mark the function as being 'extern', ie implemented in a runtime::Module. Should be bound + * to \p Integer(1). Typically accompanied by "Primitive". + * + * Type: Integer + */ +constexpr const char* kExtern = "Extern"; + /*! - * \brief Indicate the compiler that should be used for building this function. + * \brief Indicate the external codegen 'compiler' that should be used for building this function. * When this is unset or set to "default", the default compilation pipeline will be used. + * + * Type: String */ constexpr const char* kCompiler = "Compiler"; + /*! \brief Indicate if the function is a closure. */ constexpr const char* kClosure = "Closure"; /*! \brief Store a Var to parameter/Constant mapping on a Function. */ constexpr const char* kParams = "__params__"; -/*! \brief Store the unique external symbol for external compilers. */ -constexpr const char* kExternalSymbol = "ExternalSymbol"; /*! \brief Mark if the function should be avoided being optimized. */ constexpr const char* kSkipOptimization = "SkipOptimization"; /*! \brief Treat the function as a composite operator. 
*/ @@ -193,6 +208,7 @@ constexpr const char* kInline = "Inline"; constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; /*! \brief Mark the function as only composed of reshape operations. */ constexpr const char* kReshapeOnly = "relay.reshape_only"; + } // namespace attr } // namespace relay diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 9f253f8e88ba..4ee1f070c541 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -802,24 +802,6 @@ def Inline(): return _ffi_api.Inline() -def InlineComposites(target): - """Perform inlining on the given Relay IR module. The functions originate - from the MergeComposite pass based on an input pattern table will fold back - to main. Currently, this is used for the TRT BYOC which expects a single - primitive function to operate on. - - Parameters - ---------- - target: str - The byoc target for which ops need to fold back to primitive function. - Returns - ------- - ret: tvm.transform.Pass - The registered pass that performs inlining for a Relay IR module. - """ - return _ffi_api.InlineComposites(target) - - def gradient(expr, mod=None, mode="higher_order"): """ Transform the input function, @@ -1386,3 +1368,51 @@ def SplitArgs(max_function_args): The registered pass for constant folding. """ return _ffi_api.SplitArgs(max_function_args) + + +def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): + """A pass to outline all literal functions in direct call positions which have a "Compiler" + attribute. The functions are bound to unique global vars according to their existing + "global_symbol" attribute. At most one function with the same global symbol is outlined. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. 
+ + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + to prepare the IRModule before custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter) + + +def MarkCompilerFunctionsAsExtern(compiler_filter=""): + """A pass to mark all global functions which have a "Compiler" attribute matching + compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and + rewrite all calls to such functions to use the 'call_lowered' calling convention. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. + + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + cleanup the IRModule after custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. 
+ """ + return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f04..a3318bf94fc6 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint, Type type) { +GlobalVar::GlobalVar(String name_hint, Type type, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->checked_type_ = std::move(type); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 4ac1ceef26dc..505784e4bf70 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -295,8 +295,6 @@ struct Tokenizer { int line = this->line; int column = this->col; - ICHECK_EQ(Peek(), '['); - Next(); std::stringstream type_key; while (More() && Peek() != ']') { type_key << Next(); @@ -498,7 +496,7 @@ struct Tokenizer { auto token = NewToken(TokenType::kQuestion); Next(); return token; - } else if (MatchString("meta")) { + } else if (MatchString("meta[")) { return TokenizeMetaRef(); } else if (next == '#') { return TokenizeAttr(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 73b44f7361a5..344bfde3d90c 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -168,7 +168,7 @@ class TECompilerImpl : public TECompilerNode { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - // Unfortuantely, Optional() is indistinguishable from + // Unfortunately, Optional() is indistinguishable from // NullValue(), and DictAttrs() is nullptr, so to erase the attributes, we // need pass in DictAttrs()), which is a DictAttrs containing no // attributes. 
@@ -177,7 +177,7 @@ class TECompilerImpl : public TECompilerNode { function_node->body, function_node->ret_type, function_node->type_params, /* erase attributes */ DictAttrs(Map())); // Mark function as 'extern' using the "ExternalSymbol" attribute. - function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); + function = WithAttr(std::move(function), attr::kExtern, Integer(1)); module->Add(kv2.first, function); } } @@ -688,7 +688,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Nothing to lower inside primitive/external functions. return GetRef(function_node); } else { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e0b742a84090..d9730b1b5a4c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -922,7 +922,7 @@ void VMCompiler::LowerImpl(IRModule mod) { for (const auto& pair : context_.module->functions) { auto gvar = pair.first; if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kExternalSymbol).defined()) { + if (n->HasNonzeroAttr(attr::kExtern)) { // Already compiled during lowering. continue; } @@ -1131,7 +1131,7 @@ size_t VMCompiler::PopulateGlobalMap() { // Excludes PrimFuncs and externs, which are managed by the primitive_map_. 
for (const auto& kv : context_.module->functions) { if (const auto* function_node = kv.second.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { context_.global_map.emplace(kv.first, context_.global_map.size()); } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index bf0dd577a4d2..63e74144e061 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const { const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { if (const auto* function_node = base_func.as()) { if (!function_node->GetAttr(attr::kCompiler).defined() && - !function_node->GetAttr(attr::kExternalSymbol).defined() && + !function_node->HasNonzeroAttr(attr::kExtern) && !function_node->HasNonzeroAttr(attr::kSkipOptimization)) { return function_node; } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 234cafdca150..41b47401de1c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("tensor_a", "3D Tensor", "The first input.") .add_argument("tensor_b", "3D Tensor", "The second input.") diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc new file mode 100644 index 000000000000..9d956003df00 --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.cc @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.cc + * \brief Helper passes for working with functions with the "Compiler" attribute. + */ + +#include "./compiler_function_utils.h" + +#include "../op/call/call.h" +#include "tvm/relay/analysis.h" +#include "tvm/relay/expr_functor.h" + +namespace tvm { +namespace relay { +namespace transforms { +namespace { + +/*! + * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given + * module will be extended with the newly outlined functions. 
+ */ +class Outliner : public MixedModeMutator { + public: + Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) + : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* function_node = new_call->op.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + auto function = GetRef(function_node); + ICHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + << "' attribute should not have free variables"; + GlobalVar global_symbol = cache_->GetGlobalSymbol(function); + // Depending on the cache's implementation, two structurally equal (but not object equal) + // functions may be assigned the same global symbol. If so we'll lift it just once, but + // rewrite all the calls. + if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { + function = + WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); + mod_->Add(global_symbol, function); + } + return WithFields(new_call, global_symbol); + } + } + return post; + } + + private: + GlobalSymbolCache* cache_; + std::string compiler_filter_; + IRModule mod_; +}; + +/*! + * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention. 
+ */ +class CallRewriter : public MixedModeMutator { + public: + CallRewriter(std::string compiler_filter, IRModule mod) + : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* global_var_node = new_call->op.as()) { + if (const auto* function_node = + mod_->Lookup(GetRef(global_var_node)).as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + Optional opt_global_symbol = + function_node->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()); + GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value()); + CallLoweredAttrs attrs; + attrs.metadata.Set("relay_attrs", new_call->attrs); + return CallLowered(global_symbol, new_call->args, attrs, new_call->span); + } + } + } + return post; + } + + private: + std::string compiler_filter_; + IRModule mod_; +}; + +} // namespace + +GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "ExistingGlobalSymbolCache requires all functions to already have a '" + << tvm::attr::kGlobalSymbol << "' attribute"; + std::string global_symbol = opt_global_symbol.value(); + auto itr = global_vars_.find(global_symbol); + if (itr != global_vars_.end()) { + return itr->second; + } + // Ok if function does not have a checked_type, but if it does capture it in the global var. 
+ GlobalVar global_var(global_symbol, function->checked_type_, function->span); + global_vars_.emplace(global_symbol, global_var); + return global_var; +} + +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( + IRModule mod, transform::PassContext ctx) { + IRModule output_mod = GetRef(mod.CopyOnWrite()); + for (const auto& kv : mod->functions) { + const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second); + if (function_node) { + Expr new_body = + Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Add(kv.first, new_function); + } + } + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "OutlineCompilerFunctions", {}); +} + +// Any Java programmers in the house? +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) { + return OutlineCompilerFunctions(std::make_shared(), + std::move(compiler_filter)); +} + +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { + IRModule output_mod = mod->ShallowCopy(); + + // First pass, rewrite the calls. + // We have to do this before marking functions as 'extern' to know which calls to rewrite! + for (const auto& kv : mod->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + Expr new_body = + CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Update(kv.first, new_function); + } + } + + // Second pass, mark functions as 'extern'. 
+ for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { + auto new_function = WithFields( + GetRef(function_node), function_node->params, function_node->body, + function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); + new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); + output_mod->Update(kv.first, new_function); + } + } + } + + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") + .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); +TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") + .set_body_typed(MarkCompilerFunctionsAsExtern); + +} // namespace transforms +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h new file mode 100644 index 000000000000..7b5143444bf8 --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.h + * \brief Helper passes for working with functions with the "Compiler" attribute. + * + * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external + * codegen may find the following two helper passes useful: + * + * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a + * matching "Compiler" attribute to be global functions, using the "global_symbol" attribute + * already assigned. Can be used before custom lowering. + * + * Note that ideally "Compiler" attributed functions would be made global functions as early as + * possible and would stay that way. However, the GraphExecutorCodegen and AOTExecutorCodegen + * assume the entire model can be represented by a single 'main' function, and the Inline pass + * is run to respect that assumption. So this pass is mostly just to undo that Pass after modules + * have passed through the 'codegen' keyhole. + * + * See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc. + * + * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom + * cache to both allocate "global_symbol" names and ensure two structurally equal functions are + * assigned the same name, and thus lowered only once. This is used by Collage when preparing + * the optimally partitioned IRModule). 
+ * + * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" + * attribute with the same function with just an "Extern" attribute, signalling the function + * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered' + * calling convention. Can be used after lowering to cleanup the IRModule. + * + * Note that the above behaviour is hard coded within the TECompiler, but is only available to + * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point. + */ + +#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ +#define TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ + +#include +#include +#include + +#include "tvm/ir/transform.h" +#include "tvm/relay/function.h" + +namespace tvm { +namespace relay { +namespace transforms { + +/*! + * \brief Abstract class representing a cache of unique global vars keyed by functions. This can + * be used to ensure structurally equal functions are assigned the same global var object, and + * thus lowered at most once. + */ +class GlobalSymbolCache { + public: + virtual GlobalVar GetGlobalSymbol(const Function& function) = 0; +}; + +/*! + * \brief A \p GlobalSymbolCache that requires every "Compiler" attributed function to already + * have a "global_symbol" attribute. + */ +class ExistingGlobalSymbolCache : public GlobalSymbolCache { + public: + ExistingGlobalSymbolCache() = default; + + GlobalVar GetGlobalSymbol(const Function& function) final; + + private: + /*! \brief Maps already seen global symbol names to their corresponding GlobalVar objects. */ + std::unordered_map global_vars_; +}; + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The given \p GlobalSymbolCache is used to determine a unique global symbol for each + * function, which is also assigned to the "global_symbol" attribute of the new global function. 
+ * + * At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + */ +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter = ""); + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The functions are bound to unique global vars according to their existing + * "global_symbol" attribute. At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + * to prepare the IRModule before custom lowering. + */ +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = ""); + +/*! + * \brief A pass to mark all global functions which have a "Compiler" attribute matching + * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and + * rewrite all calls to such functions to use the 'call_lowered' calling convention. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + * cleanup the IRModule after custom lowering. 
+ */ +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); + +} // namespace transforms +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 45cb8271b074..18d2de1bdede 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -84,7 +84,7 @@ class PurityVisitor : ExprFunctor { for (const auto& kv : mod_->functions) { if (const auto* function_node = kv.second.as()) { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Ignore primitive and external functions. continue; } @@ -133,9 +133,11 @@ class PurityVisitor : ExprFunctor { Purity VisitExpr_(const GlobalVarNode* global_var_node) final { auto global_var = GetRef(global_var_node); + ICHECK(mod_->ContainGlobalVar(global_var_node->name_hint)) + << "No definition for '" << global_var_node->name_hint << "'"; auto func = mod_->Lookup(global_var); if (const auto* function_node = func.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { return VisitGlobalFunction(global_var, GetRef(function_node)); } } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index c55b6778093e..012b3579494f 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -110,7 +110,7 @@ class Inliner : ExprMutator { if (!function_node->body.defined()) return false; // The function must be annotated with the inline attribute. - // (Note that external functions do not have this attribute!) + // (Note that partitioned functions and external functions do not have this attribute!) 
if (!function_node->HasNonzeroAttr(attr::kInline)) return false; // The function is not able to be inlined if any callee under the CallGraph @@ -136,8 +136,7 @@ class Inliner : ExprMutator { auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined() && - !func->GetAttr(attr::kExternalSymbol).defined()) { + if (!func->GetAttr(attr::kCompiler).defined() && !func->HasNonzeroAttr(attr::kExtern)) { ICHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py new file mode 100644 index 000000000000..13e0f98e79f1 --- /dev/null +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License +"""Unit tests for the OutlineCompilerFunctionsWithExistingGlobalSymbols and + MarkCompilerFunctionsAsExtern external codegen helper passes.""" + +import tvm +import tvm.testing +import numpy as np + + +def make_const(dtype, shape): + return tvm.relay.const(np.random.rand(*shape).astype(dtype)) + + +def make_consts(dtype, shapes): + return [make_const(dtype, shape) for shape in shapes] + + +metatable = { + "relay.Constant": make_consts( + "float16", + [ + (2304, 768), # 0 + (2304,), # 1 + (600, 32, 64), # 2 + ], + ), + "attributes": [{"relay_attrs": None}], +} + + +def inlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %0 = fn(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + }; + %1 = %0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", 
Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + +def expected_outlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def 
expected_extern_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, meta[relay.Constant][0], meta[relay.Constant][1]), metadata=meta[attributes][0]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Extern=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def test_outline_compiler_functions_with_existing_global_symbols(): + actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( + "cutlass" + )(inlined_mod()) + tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) + + +def test_mark_compiler_functions_as_extern(): + 
actual_extern_mod = tvm.relay.transform.MarkCompilerFunctionsAsExtern("cutlass")( + expected_outlined_mod() + ) + tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main() From f88b5a83b0886b2edaa4c7248cb3db28e148c77e Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 2 Jun 2022 17:17:22 -0700 Subject: [PATCH 2/2] - Andrew's comments --- include/tvm/relay/function.h | 16 +++++++++++----- python/tvm/relay/transform/transform.py | 14 +++++++++----- src/relay/backend/te_compiler.cc | 2 +- src/relay/transforms/compiler_function_utils.cc | 13 ++++++++++++- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 29ce07753398..052d04fe2411 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -172,23 +172,29 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func); namespace attr { /*! - * \brief Mark the function as a primitive function. Should be bound to \p Integer(1). + * \brief Mark the function as representing a sub-graph which is to be lowered or compiled as + * a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc. + * If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below. + * The function body should be considered opaque by Relay, and many passes simply ignore these + * functions. * * Type: Integer */ constexpr const char* kPrimitive = "Primitive"; /*! - * \brief Mark the function as being 'extern', ie implemented in a runtime::Module. Should be bound - * to \p Integer(1). Typically accompanied by "Primitive". + * \brief Mark the function as externally implemented, ie bound in a runtime::Module within the + * IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally + * the only attribute when present. 
* * Type: Integer */ constexpr const char* kExtern = "Extern"; /*! - * \brief Indicate the external codegen 'compiler' that should be used for building this function. - * When this is unset or set to "default", the default compilation pipeline will be used. + * \brief Indicates the name of the external codegen 'compiler' that should be used to lower + * or compile the function other than TVM's default lowering pipeline. The name may correspond + * to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'. * * Type: String */ diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 4ee1f070c541..694dbb45218c 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1371,8 +1371,10 @@ def SplitArgs(max_function_args): def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): - """A pass to outline all literal functions in direct call positions which have a "Compiler" - attribute. The functions are bound to unique global vars according to their existing + """Outlines all literal functions in direct call positions which have a "Compiler" + attribute. + + The outlined functions are bound to unique global vars according to their existing "global_symbol" attribute. At most one function with the same global symbol is outlined. If compiler_filter is non-empty only functions with that as their attribute value are @@ -1395,9 +1397,11 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): def MarkCompilerFunctionsAsExtern(compiler_filter=""): - """A pass to mark all global functions which have a "Compiler" attribute matching - compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and - rewrite all calls to such functions to use the 'call_lowered' calling convention. + """Marks all global functions which have a "Compiler" attribute matching + compiler_filter as 'extern'. 
+ + The function's attributes are replaced with a single "Extern" attribute, and + all calls to the function are switched to use the 'call_lowered' calling convention. If compiler_filter is non-empty only functions with that as their attribute value are outlined. diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 344bfde3d90c..c78f3abd6ecc 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -176,7 +176,7 @@ class TECompilerImpl : public TECompilerNode { WithFields(GetRef(function_node), function_node->params, function_node->body, function_node->ret_type, function_node->type_params, /* erase attributes */ DictAttrs(Map())); - // Mark function as 'extern' using the "ExternalSymbol" attribute. + // Mark function as 'extern'. function = WithAttr(std::move(function), attr::kExtern, Integer(1)); module->Add(kv2.first, function); } diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 9d956003df00..b98d089b346a 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -49,8 +49,9 @@ class Outliner : public MixedModeMutator { if (opt_compiler.defined() && (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { auto function = GetRef(function_node); - ICHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler << "' attribute should not have free variables"; + // Ask the cache to supply a unique global var for this function. GlobalVar global_symbol = cache_->GetGlobalSymbol(function); // Depending on the cache's implementation, two structurally equal (but not object equal) // functions may be assigned the same global symbol. 
If so we'll lift it just once, but @@ -60,6 +61,7 @@ class Outliner : public MixedModeMutator { WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); mod_->Add(global_symbol, function); } + // Update the call. return WithFields(new_call, global_symbol); } } @@ -67,8 +69,15 @@ class Outliner : public MixedModeMutator { } private: + /*! + * \brief A cached mapping from functions to global variables. Depending on the implementation + * the cache may generate fresh symbols or require the function to already have a "global_symbol" + * attribute, and may share symbols between structurally equal functions. + */ GlobalSymbolCache* cache_; + /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ std::string compiler_filter_; + /*! \brief Module being rewritten. */ IRModule mod_; }; @@ -102,7 +111,9 @@ class CallRewriter : public MixedModeMutator { } private: + /*! \brief If non-empty, the "Compiler" attribute value to require on functions whose calls are rewritten. */ std::string compiler_filter_; + /*! \brief Module being rewritten. */ IRModule mod_; };