diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index 58dac06382c0..a69e2d410529 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -26,7 +26,7 @@ from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item
-from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem
+from tvm.relay.expr import Call, Constant, TupleGetItem
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 from tvm.relay.op.contrib.register import register_pattern_table
@@ -864,7 +864,11 @@ def pattern_table() -> List[
             binary_op_pattern_with_const("nn.dense"),
             make_predicate(dense_checker),
         ),
-        ("tensorrt.bias_add", binary_op_pattern("nn.bias_add"), make_predicate(bias_add_checker)),
+        (
+            "tensorrt.nn.bias_add",
+            binary_op_pattern("nn.bias_add"),
+            make_predicate(bias_add_checker),
+        ),
         (
             "tensorrt.nn.batch_matmul",
             binary_op_pattern("nn.batch_matmul"),
@@ -1062,7 +1066,6 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo
     for var in params:
         # In implicit batch mode, all inputs must have same batch size
         # TODO: (codeislife99) : Fix different dynamic batch size inputs
-
         if isinstance(var.checked_type, relay.TupleType):
             for tupe_type in var.checked_type.fields:
                 # Scalar inputs not allowed
@@ -1079,64 +1082,32 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo
                 return False
             if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
                 input_batch_sizes.append(int(var.checked_type.shape[0]))
+
     if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
-        logger.info("tensorrt: inputs have different batch sizes")
+        logger.info("tensorrt: inputs have different batch sizes: %s", input_batch_sizes)
         return False
+
     if get_tensorrt_remove_no_mac_subgraphs():
-        return IsComputeIntensiveGraph().is_graph_compute_intensive(body)
+        if not IsComputeIntensiveGraph().is_graph_compute_intensive(body):
+            logger.info("tensorrt: not a compute-intensive sub-graph")
+            return False
+
     return True


 def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
     """
-    Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs
-    is set).
-    """
-
-    class SubgraphRemover(ExprMutator):
-        """
-        Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
-        """
-
-        def __init__(
-            self, subgraphs_to_remove: List[str], mod: tvm.IRModule, new_mod: tvm.IRModule
-        ) -> None:
-            ExprMutator.__init__(self)
-            self.subgraphs_to_remove = subgraphs_to_remove
-            self.mod = mod
-            self.new_mod = new_mod
-
-        def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr:
-            if isinstance(call.op, GlobalVar):
-                name = call.op.name_hint
-                if name in self.subgraphs_to_remove:
-                    # "Inline" the subgraph back into new main function.
-                    func = self.mod[name]
-                    var_map = {}
-                    for arg, param in zip(call.args, func.params):
-                        var_map[param] = super().visit(arg)
-                    new_body = relay.bind(func.body, var_map)
-                    return new_body
-                if name != "main":
-                    args = []
-                    for arg in call.args:
-                        args.append(super().visit(arg))
-                    return call.op(*args)
-            return super().visit_call(call)
-
-    subgraphs_to_remove: List[str] = []
-    # Remove invalid subgraphs
-    for subgraph in mod.get_global_vars():
-        name = subgraph.name_hint
-        if not mod[name].attrs or mod[name].attrs["Compiler"] != "tensorrt":
-            continue
-        if not is_valid_subgraph(mod[name].params, mod[name].body):
-            subgraphs_to_remove.append(name)
-    # Create new pruned module
-    new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
-    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
-    new_mod = transform.RemoveUnusedFunctions()(new_mod)
-    return new_mod
+    Un-partition those partitions which:
+     - have no multiply-accumulates (if remove_no_mac_subgraphs is True)
+     - can't actually be supported by TensorRT now that we see the whole partition."""
+    global_vars_to_inline = [
+        gv
+        for gv in mod.get_global_vars()
+        if mod[gv].attrs
+        and mod[gv].attrs["Compiler"] == "tensorrt"
+        and not is_valid_subgraph(mod[gv].params, mod[gv].body)
+    ]
+    return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod)


 class RemoveDropout(ExprMutator):
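A minimal, self-contained sketch of the new pruning path (the module, partition name, and scalar rejection below are illustrative, not from the patch): a "Compiler=tensorrt" partition whose input is a scalar fails `is_valid_subgraph`, so `prune_tensorrt_subgraphs` now inlines it back into `@main` via the generic pass instead of the hand-rolled `SubgraphRemover`.

```python
import tvm
from tvm.relay.op.contrib.tensorrt import prune_tensorrt_subgraphs

# Hypothetical partitioned module: the partition takes a scalar input, which
# is_valid_subgraph rejects, so the partition is un-partitioned back into @main.
mod = tvm.parser.parse(
    """
    #[version = "0.0.5"]
    def @main(%x: float32) -> float32 {
      @tvmgen_default_tensorrt_main_0(%x)
    }
    def @tvmgen_default_tensorrt_main_0(%y: float32, Compiler="tensorrt",
        global_symbol="tvmgen_default_tensorrt_main_0", Primitive=1) -> float32 {
      add(%y, 1f)
    }
    """
)
pruned = prune_tensorrt_subgraphs(mod)  # @main's body becomes add(%x, 1f)
```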
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 694dbb45218c..979664f72ca3 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1420,3 +1420,25 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
         The pass.
     """
     return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)
+
+
+def InlineCompilerFunctionsBoundTo(global_vars):
+    """Inlines all global functions bound to a global var in global_vars.
+
+    Both the global "Compiler" attributed function, and any calls to "Composite" functions in its
+    body, are inlined.
+
+    This pass may be useful for external codegen which needs to undo partitioning based on
+    properties of the entire partition.
+
+    Parameters
+    ----------
+    global_vars : Array[tvm.relay.GlobalVar]
+        The global vars of all 'Compiler' functions to inline.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The pass.
+    """
+    return _ffi_api.InlineCompilerFunctionsBoundTo(global_vars)
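A hedged usage sketch of the new wrapper (the module, the "example" compiler name, and `@part` are invented for illustration). It also demonstrates the docstring's claim that the inner "Composite" call is flattened along with the outer "Compiler" function: after the pass, `@main` computes `nn.relu(%x)` directly and `@part` is gone.

```python
import tvm
from tvm import relay

mod = tvm.parser.parse(
    """
    #[version = "0.0.5"]
    def @main(%x: Tensor[(8, 8), float32]) -> Tensor[(8, 8), float32] {
      @part(%x)
    }
    def @part(%y: Tensor[(8, 8), float32], Compiler="example",
        global_symbol="part", Primitive=1) -> Tensor[(8, 8), float32] {
      %0 = fn (%z: Tensor[(8, 8), float32], Composite="example.relu") -> Tensor[(8, 8), float32] {
        nn.relu(%z)
      };
      %0(%y)
    }
    """
)
mod = relay.transform.InlineCompilerFunctionsBoundTo([mod.get_global_var("part")])(mod)
```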
diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc
index 3df07e4c57f5..1b0f002f1def 100644
--- a/src/relay/transforms/compiler_function_utils.cc
+++ b/src/relay/transforms/compiler_function_utils.cc
@@ -27,12 +27,28 @@
 #include "../op/call/call.h"
 #include "tvm/relay/analysis.h"
 #include "tvm/relay/expr_functor.h"
+#include "tvm/relay/transform.h"

 namespace tvm {
 namespace relay {
 namespace transforms {
 namespace {

+/*!
+ * \brief Returns the \p FunctionNode of \p expr if it is a "Compiler" function which should
+ * be processed by a pass using \p compiler_filter. Otherwise returns null.
+ */
+const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) {
+  if (const auto* function_node = expr.as<FunctionNode>()) {
+    Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
+    if (opt_compiler.defined() &&
+        (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
+      return function_node;
+    }
+  }
+  return nullptr;
+}
+
 /*!
  * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given
  * module will be extended with the newly outlined functions.
@@ -44,26 +60,22 @@ class Outliner : public MixedModeMutator {

   Expr Rewrite_(const CallNode* pre, const Expr& post) final {
     Call new_call = Downcast<Call>(post);
-    if (const auto* function_node = new_call->op.as<FunctionNode>()) {
-      Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
-      if (opt_compiler.defined() &&
-          (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
-        auto function = GetRef<Function>(function_node);
-        DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
-                                           << "' attribute should not have free variables";
-        // Ask the cache to supply a unique global var for this function.
-        GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
-        // Depending on the cache's implementation, two structurally equal (but not object equal)
-        // functions may be assigned the same global symbol. If so we'll lift it just once, but
-        // rewrite all the calls.
-        if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
-          function =
-              WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
-          mod_->Add(global_symbol, function);
-        }
-        // Update the call.
-        return WithFields(new_call, global_symbol);
+    if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) {
+      auto function = GetRef<Function>(function_node);
+      DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
+                                         << "' attribute should not have free variables";
+      // Ask the cache to supply a unique global var for this function.
+      GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
+      // Depending on the cache's implementation, two structurally equal (but not object
+      // equal) functions may be assigned the same global symbol. If so we'll lift it just
+      // once, but rewrite all the calls.
+      if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
+        function =
+            WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
+        mod_->Add(global_symbol, function);
       }
+      // Update the call.
+      return WithFields(new_call, global_symbol);
     }
     return post;
   }
@@ -71,8 +83,8 @@ class Outliner : public MixedModeMutator {
  private:
   /*!
    * \brief A cached mapping from functions to global variables. Depending on the implementation
-   * the cache may generate fresh symbols or require the function to already have a "global_symbol"
-   * attribute, and may share symbols between structurally equal functions.
+   * the cache may generate fresh symbols or require the function to already have a
+   * "global_symbol" attribute, and may share symbols between structurally equal functions.
    */
   GlobalSymbolCache* cache_;
   /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
@@ -81,6 +93,72 @@ class Outliner : public MixedModeMutator {
   IRModule mod_;
 };

+/*!
+ * \brief Inline immediate calls to "Composite" functions.
+ */
+class InnerInliner : public MixedModeMutator {
+ public:
+  InnerInliner() = default;
+
+ private:
+  using MixedModeMutator::Rewrite_;
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
+    Call new_call = Downcast<Call>(post);
+    if (const auto* function_node = new_call->op.as<FunctionNode>()) {
+      ICHECK(function_node->GetAttr<String>(attr::kComposite).defined());
+      ICHECK_EQ(function_node->params.size(), new_call->args.size());
+      Map<Var, Expr> subst;
+      for (size_t i = 0; i < new_call->args.size(); ++i) {
+        subst.Set(function_node->params[i], new_call->args[i]);
+      }
+      return Bind(function_node->body, subst);
+    }
+    return post;
+  }
+};
+
+/*!
+ * \brief Inline calls to global "Compiler" functions with global var in \p global_vars.
+ * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body
+ * are inlined.
+ */
+class OuterInliner : public MixedModeMutator {
+ public:
+  OuterInliner(IRModule mod, Array<GlobalVar> global_vars_)
+      : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {}
+
+ private:
+  using MixedModeMutator::Rewrite_;
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
+    Call new_call = Downcast<Call>(post);
+    if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
+      auto global_var = GetRef<GlobalVar>(global_var_node);
+      if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) {
+        BaseFunc base_func = mod_->Lookup(global_var);
+        const auto* function_node = base_func.as<FunctionNode>();
+        ICHECK(function_node);
+        ICHECK(function_node->GetAttr<String>(attr::kCompiler).defined());
+        ICHECK_EQ(function_node->params.size(), new_call->args.size());
+        Map<Var, Expr> subst;
+        for (size_t i = 0; i < new_call->args.size(); ++i) {
+          subst.Set(function_node->params[i], new_call->args[i]);
+        }
+        Expr new_body = InnerInliner().VisitExpr(function_node->body);
+        return Bind(new_body, subst);
+      }
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief Original module we are processing. */
+  IRModule mod_;
+  /*! \brief Global vars of functions to inline. */
+  Array<GlobalVar> global_vars_;
+};
+
 }  // namespace

 GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -106,10 +184,10 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
       [cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
           IRModule mod, transform::PassContext ctx) {
-        IRModule output_mod = GetRef<IRModule>(mod.CopyOnWrite());
+        VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
+        IRModule output_mod = mod->ShallowCopy();
         for (const auto& kv : mod->functions) {
-          const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second);
-          if (function_node) {
+          if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
             Expr new_body =
                 Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body);
             Function new_function =
@@ -117,6 +195,7 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach
             output_mod->Add(kv.first, new_function);
           }
         }
+        VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod);
         return output_mod;
       };

@@ -132,31 +211,57 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
 transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
       [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
+        VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
         IRModule output_mod = mod->ShallowCopy();
         for (const auto& kv : mod->functions) {
-          if (const auto* function_node = kv.second.as<FunctionNode>()) {
-            Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
-            if (opt_compiler.defined() &&
-                (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
-              auto new_function = WithFields(
-                  GetRef<Function>(function_node), function_node->params, function_node->body,
-                  function_node->ret_type, function_node->type_params,
-                  /* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
-              new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
-              output_mod->Update(kv.first, new_function);
-            }
+          if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) {
+            auto new_function =
+                WithFields(GetRef<Function>(function_node), function_node->params,
+                           function_node->body, function_node->ret_type, function_node->type_params,
+                           /* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
+            new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
+            output_mod->Update(kv.first, new_function);
           }
         }
+        VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod);
         return output_mod;
       };

   return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
 }

+transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
+        VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars);
+        if (global_vars.empty()) {
+          return mod;
+        }
+        VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
+        IRModule output_mod = mod->ShallowCopy();
+        for (const auto& kv : mod->functions) {
+          if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) {
+            output_mod->Remove(kv.first);
+          } else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+            Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body);
+            Function new_function =
+                WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
+            output_mod->Add(kv.first, new_function);
+          }
+        }
+        VLOG(1) << "InlineCompilerFunctionsBoundTo result:" << std::endl << PrettyPrint(output_mod);
+        return output_mod;
+      };
+
+  return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsBoundTo", {});
+}
+
 TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols")
     .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
 TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
     .set_body_typed(MarkCompilerFunctionsAsExtern);
+TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
+    .set_body_typed(InlineCompilerFunctionsBoundTo);

 }  // namespace transforms
 }  // namespace relay
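Both inliners above reduce to the same `Bind` step: substitute each parameter of the callee with the corresponding call argument. A Python analogue using `relay.bind` (the variables are invented for illustration; `relay.bind` is the same helper the deleted `SubgraphRemover` used):

```python
from tvm import relay

a = relay.var("a", shape=(4,), dtype="float32")
b = relay.var("b", shape=(4,), dtype="float32")
callee = relay.Function([a, b], relay.add(a, b))  # stands in for a partition body

p = relay.var("p", shape=(4,), dtype="float32")
q = relay.var("q", shape=(4,), dtype="float32")
# Inlining callee(p, q): bind {a: p, b: q} into the body, yielding add(p, q).
inlined = relay.bind(callee.body, {a: p, b: q})
```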
diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h
index 9d1dcd9f21a2..6664594fc0a0 100644
--- a/src/relay/transforms/compiler_function_utils.h
+++ b/src/relay/transforms/compiler_function_utils.h
@@ -22,10 +22,10 @@
 * \brief Helper passes for working with functions with the "Compiler" attribute.
 *
 * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external
- * codegen may find the following two helper passes useful:
+ * codegen may find the following helpers useful:
 *
- * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a
- *   matching "Compiler" attribute to be global functions, using the "global_symbol" attribute
+ * - The \p OutlineCompilerFunctionsWithExistingGlobalSymbols pass will lift inline functions with
+ *   a matching "Compiler" attribute to be global functions, using the "global_symbol" attribute
 *   already assigned. Can be used before custom lowering.
 *
 *   Note that ideally "Compiler" attributed functions would be made global functions as early as
@@ -36,15 +36,22 @@
 *
 *   See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc.
 *
- * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom
- *   cache to both allocate "global_symbol" names and ensure two strucurally equal functions are
- *   assigned the same name, and thus lowered only once. This is used by Collage when preparing
- *   the optimally partitioned IRModule).
+ * - (The \p OutlineCompilerFunctions pass is a more general version of the above which can use
+ *   a custom cache to both allocate "global_symbol" names and ensure two structurally equal
+ *   functions are assigned the same name, and thus lowered only once. This is used by Collage
+ *   when preparing the optimally partitioned IRModule).
 *
- * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler"
- *   attribute with the same function with just an "Extern" attribute, signalling the function
- *   has been dealt with. However calls to such functions will be left unchanged. Can be used
- *   after lowering to cleanup the IRModule.
+ * - The \p MarkCompilerFunctionsAsExtern pass will update the attributes of global functions
+ *   with a matching "Compiler" attribute to have just the "Extern" attribute. That will signal
+ *   the function has been dealt with. However calls to such functions will be left unchanged.
+ *   Can be used after lowering to clean up the IRModule.
+ *
+ * - The \p InlineCompilerFunctionsBoundTo pass can selectively inline global functions with a
+ *   matching "Compiler" attribute whose name appears in the given set. Obviously it's more
+ *   sensible to not create that function in the first place, however some external codegen have
+ *   rules to accept or reject partitionings based on the overall partitioned function body. This
+ *   pass can be used to do the legwork, and will take care to not only inline the outer
+ *   "Compiler" annotated function, but also any "Composite" annotated functions in its body.
 */

 #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -126,6 +133,16 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
 */
 transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");

+/*!
+ * \brief A pass to inline all global "Compiler" functions which are bound to a global var
+ * in \p global_vars. Both the global function and any calls to "Composite" functions in its
+ * body are inlined.
+ *
+ * This pass may be useful for external codegen which needs to undo partitioning based on
+ * properties of the entire partition.
+ */
+transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
+
 }  // namespace transforms
 }  // namespace relay
 }  // namespace tvm
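A hedged sketch of how these helpers are intended to compose in a custom external-codegen flow, assuming the codegen has already decided which partitions to reject on whole-partition grounds. All names here ("mycodegen", `rejected_global_vars`, the helper itself) are placeholders, not APIs from this patch:

```python
import tvm
from tvm import relay

def unpartition_rejected(mod, rejected_global_vars):
    """Undo the rejected partitions, then mark the surviving partitions extern
    (the latter would normally run only after lowering has dealt with them)."""
    seq = tvm.transform.Sequential(
        [
            relay.transform.InlineCompilerFunctionsBoundTo(rejected_global_vars),
            relay.transform.MarkCompilerFunctionsAsExtern("mycodegen"),
        ]
    )
    return seq(mod)
```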
diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py
index b9eb11547595..66abeff8ab29 100644
--- a/tests/python/relay/transform/test_compiler_function_utils.py
+++ b/tests/python/relay/transform/test_compiler_function_utils.py
@@ -42,7 +42,7 @@ def make_consts(dtype, shapes):
 }


-def inlined_mod():
+def original_mod():
     return tvm.parser.parse(
         """
         #[version = "0.0.5"]
@@ -143,10 +143,35 @@ def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i
     )


+def expected_inlined_mod():
+    return tvm.parser.parse(
+        """
+        #[version = "0.0.5"]
+        def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) {
+            %0 = nn.dense(%x0, meta[relay.Constant][0], units=2304);
+            %1 = add(%0, meta[relay.Constant][1]);
+            %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16],
+                    Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] {
+                %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16],
+                         PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] {
+                    nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True)
+                };
+                %6(%y_3_i0, %y_3_i1)
+            };
+            %3 = %2(%x3, meta[relay.Constant][2]);
+            (%1, %3)
+        }
+        """,
+        "from_string",
+        None,
+        metatable,
+    )
+
+
 def test_outline_compiler_functions_with_existing_global_symbols():
     actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols(
         "cutlass"
-    )(inlined_mod())
+    )(original_mod())
     tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True)


@@ -157,5 +182,12 @@ def test_mark_compiler_functions_as_extern():
     tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True)


+def test_inline_compiler_functions():
+    mod = expected_outlined_mod()
+    gv = mod.get_global_var("tvmgen_default_cutlass_main_0")
+    actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctionsBoundTo([gv])(mod)
+    tvm.ir.assert_structural_equal(actual_inlined_mod, expected_inlined_mod(), map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()