From 6cc6e0cc99f5d20c1b5585d9c46e430bba9b16ff Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 27 Jun 2022 12:04:38 -0700 Subject: [PATCH 1/3] [BYOC] InlineCompilerFunctions helper pass The TensorRT BYOC integration needs to 'undo' partitionings in some situations. Add an InlineCompilerFunctions pass to make that robust. In particular, it must undo both the 'partitioning' (ie separating out the "Compiler" function) and any 'compositing' (ie separating out small sub-graphs as "Composite" functions). Fix misspelled nn.bias_add while there. Note that the current implementation is broken but untested in CI. I have all the tests fixed in a follow-up PR. --- python/tvm/relay/op/contrib/tensorrt.py | 75 +++----- python/tvm/relay/transform/transform.py | 22 +++ .../transforms/compiler_function_utils.cc | 175 ++++++++++++++---- .../transforms/compiler_function_utils.h | 39 ++-- .../transform/test_compiler_function_utils.py | 36 +++- 5 files changed, 247 insertions(+), 100 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 58dac06382c0..55aef9a5f986 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -864,7 +864,11 @@ def pattern_table() -> List[ binary_op_pattern_with_const("nn.dense"), make_predicate(dense_checker), ), - ("tensorrt.bias_add", binary_op_pattern("nn.bias_add"), make_predicate(bias_add_checker)), + ( + "tensorrt.nn.bias_add", + binary_op_pattern("nn.bias_add"), + make_predicate(bias_add_checker), + ), ( "tensorrt.nn.batch_matmul", binary_op_pattern("nn.batch_matmul"), @@ -1062,7 +1066,6 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo for var in params: # In implicit batch mode, all inputs must have same batch size # TODO: (codeislife99) : Fix different dynamic batch size inputs - if isinstance(var.checked_type, relay.TupleType): for tupe_type in var.checked_type.fields: # Scalar inputs not allowed @@ -1079,64 +1082,32 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo return False if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any): input_batch_sizes.append(int(var.checked_type.shape[0])) + if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1: - logger.info("tensorrt: inputs have different batch sizes") + logger.info("tensorrt: inputs have different batch sizes: %s", input_batch_sizes) return False + if get_tensorrt_remove_no_mac_subgraphs(): - return IsComputeIntensiveGraph().is_graph_compute_intensive(body) + if not IsComputeIntensiveGraph().is_graph_compute_intensive(body): + logger.info("tensorrt: not a compute-intensize sub-graph") + return False + return True def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule: """ - Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs - is set). - """ - - class SubgraphRemover(ExprMutator): - """ - Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen. - """ - - def __init__( - self, subgraphs_to_remove: List[str], mod: tvm.IRModule, new_mod: tvm.IRModule - ) -> None: - ExprMutator.__init__(self) - self.subgraphs_to_remove = subgraphs_to_remove - self.mod = mod - self.new_mod = new_mod - - def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr: - if isinstance(call.op, GlobalVar): - name = call.op.name_hint - if name in self.subgraphs_to_remove: - # "Inline" the subgraph back into new main function. - func = self.mod[name] - var_map = {} - for arg, param in zip(call.args, func.params): - var_map[param] = super().visit(arg) - new_body = relay.bind(func.body, var_map) - return new_body - if name != "main": - args = [] - for arg in call.args: - args.append(super().visit(arg)) - return call.op(*args) - return super().visit_call(call) - - subgraphs_to_remove: List[str] = [] - # Remove invalid subgraphs - for subgraph in mod.get_global_vars(): - name = subgraph.name_hint - if not mod[name].attrs or mod[name].attrs["Compiler"] != "tensorrt": - continue - if not is_valid_subgraph(mod[name].params, mod[name].body): - subgraphs_to_remove.append(name) - # Create new pruned module - new_mod = tvm.IRModule(mod.functions, mod.type_definitions) - new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"]) - new_mod = transform.RemoveUnusedFunctions()(new_mod) - return new_mod + Un-partition those partitions which: + - have no multiply-accumulates (if remove_no_mac_subgraphs is True) + - can't actually be supported by TensorRT now that we see the whole partition.""" + global_vars_to_inline = [ + gv + for gv in mod.get_global_vars() + if mod[gv].attrs + and mod[gv].attrs["Compiler"] == "tensorrt" + and not is_valid_subgraph(mod[gv].params, mod[gv].body) + ] + return relay.transform.InlineCompilerFunctions(global_vars_to_inline)(mod) class RemoveDropout(ExprMutator): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 694dbb45218c..663a9291b2f8 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1420,3 +1420,25 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): The pass. """ return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) + + +def InlineCompilerFunctions(global_vars): + """Inlines all global functions bound to a global var in global_vars. + + Both the global "Compiler" attributed function, and any "Composite" functions it its body are + inlined. + + This pass may be useful for external codegen which needs to undo partitioning based on + properties of the entire partition. + + Parameters + ---------- + global_vars : Array[tvm.relay.GlobalVar] + The global vars of all 'Compiler' functions to inline. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.InlineCompilerFunctions(global_vars) diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 3df07e4c57f5..4fc8a01ed75a 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -27,12 +27,28 @@ #include "../op/call/call.h" #include "tvm/relay/analysis.h" #include "tvm/relay/expr_functor.h" +#include "tvm/relay/transform.h" namespace tvm { namespace relay { namespace transforms { namespace { +/*! + * \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should + * be processed by a pass using \p compiler_filter. Otherwise returns null. + */ +const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { + return function_node; + } + } + return nullptr; +} + /*! * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given * module will be extended with the newly outlined functions. @@ -44,26 +60,22 @@ class Outliner : public MixedModeMutator { Expr Rewrite_(const CallNode* pre, const Expr& post) final { Call new_call = Downcast(post); - if (const auto* function_node = new_call->op.as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { - auto function = GetRef(function_node); - DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler - << "' attribute should not have free variables"; - // Ask the cache to supply a unique global var for this function. - GlobalVar global_symbol = cache_->GetGlobalSymbol(function); - // Depending on the cache's implementation, two structurally equal (but not object equal) - // functions may be assigned the same global symbol. If so we'll lift it just once, but - // rewrite all the calls. - if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { - function = - WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); - mod_->Add(global_symbol, function); - } - // Update the call. - return WithFields(new_call, global_symbol); + if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) { + auto function = GetRef(function_node); + DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + << "' attribute should not have free variables"; + // Ask the cache to supply a unique global var for this function. + GlobalVar global_symbol = cache_->GetGlobalSymbol(function); + // Depending on the cache's implementation, two structurally equal (but not object + // equal) functions may be assigned the same global symbol. If so we'll lift it just + // once, but rewrite all the calls. + if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { + function = + WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); + mod_->Add(global_symbol, function); } + // Update the call. + return WithFields(new_call, global_symbol); } return post; } @@ -71,8 +83,8 @@ class Outliner : public MixedModeMutator { private: /*! * \brief A cached mapping from functions to global variables. Depending on the implementation - * the cache may generate fresh symbols or require the function to already have a "global_symbol" - * attribute, and may share symbols between structurally equal functions. + * the cache may generate fresh symbols or require the function to already have a + * "global_symbol" attribute, and may share symbols between structurally equal functions. */ GlobalSymbolCache* cache_; /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ @@ -81,6 +93,72 @@ class Outliner : public MixedModeMutator { IRModule mod_; }; +/*! + * \brief Inline immediate calls to "Composite" functions. + */ +class InnerInliner : public MixedModeMutator { + public: + InnerInliner() = default; + + private: + using MixedModeMutator::Rewrite_; + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* function_node = new_call->op.as()) { + ICHECK(function_node->GetAttr(attr::kComposite).defined()); + ICHECK_EQ(function_node->params.size(), new_call->args.size()); + Map subst; + for (size_t i = 0; i < new_call->args.size(); ++i) { + subst.Set(function_node->params[i], new_call->args[i]); + } + return Bind(function_node->body, subst); + } + return post; + } +}; + +/*! + * \brief Inline calls to global "Compiler" functions with global var in \p global_vars. + * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body + * are inlined. + */ +class OuterInliner : public MixedModeMutator { + public: + OuterInliner(IRModule mod, Array global_vars_) + : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {} + + private: + using MixedModeMutator::Rewrite_; + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* global_var_node = new_call->op.as()) { + auto global_var = GetRef(global_var_node); + if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) { + BaseFunc base_func = mod_->Lookup(global_var); + const auto* function_node = base_func.as(); + ICHECK(function_node); + ICHECK(function_node->GetAttr(attr::kCompiler).defined()); + ICHECK_EQ(function_node->params.size(), new_call->args.size()); + Map subst; + for (size_t i = 0; i < new_call->args.size(); ++i) { + subst.Set(function_node->params[i], new_call->args[i]); + } + Expr new_body = InnerInliner().VisitExpr(function_node->body); + return Bind(new_body, subst); + } + } + return post; + } + + private: + /*! \brief Original module we are processing. */ + IRModule mod_; + /*! \brief Global vars of functions to inline. */ + Array global_vars_; +}; + } // namespace GlobalSymbolCache::~GlobalSymbolCache() = default; @@ -106,10 +184,10 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach runtime::TypedPackedFunc pass_func = [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( IRModule mod, transform::PassContext ctx) { - IRModule output_mod = GetRef(mod.CopyOnWrite()); + VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); + IRModule output_mod = mod->ShallowCopy(); for (const auto& kv : mod->functions) { - const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second); - if (function_node) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { Expr new_body = Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); Function new_function = @@ -117,6 +195,7 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach output_mod->Add(kv.first, new_function); } } + VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); return output_mod; }; @@ -132,31 +211,57 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { + VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod); IRModule output_mod = mod->ShallowCopy(); for (const auto& kv : mod->functions) { - if (const auto* function_node = kv.second.as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { - auto new_function = WithFields( - GetRef(function_node), function_node->params, function_node->body, - function_node->ret_type, function_node->type_params, - /* erase attributes */ DictAttrs(Map())); - new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); - output_mod->Update(kv.first, new_function); - } + if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) { + auto new_function = + WithFields(GetRef(function_node), function_node->params, + function_node->body, function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); + new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); + output_mod->Update(kv.first, new_function); } } + VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod); return output_mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); } +transform::Pass InlineCompilerFunctions(Array global_vars) { + runtime::TypedPackedFunc pass_func = + [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { + VLOG(1) << "InlineCompilerFunctions with global_vars: " << PrettyPrint(global_vars); + if (global_vars.empty()) { + return mod; + } + VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); + IRModule output_mod = mod->ShallowCopy(); + for (const auto& kv : mod->functions) { + if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) { + output_mod->Remove(kv.first); + } else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Add(kv.first, new_function); + } + } + VLOG(1) << "InlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsImpl", {}); +} + TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") .set_body_typed(MarkCompilerFunctionsAsExtern); +TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctions") + .set_body_typed(InlineCompilerFunctions); } // namespace transforms } // namespace relay diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 9d1dcd9f21a2..5cd89cbb9d2d 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -22,10 +22,10 @@ * \brief Helper passes for working with functions with the "Compiler" attribute. * * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external - * codegen may find the following two helper passes useful: + * codegen may find the following helpers useful: * - * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a - * matching "Compiler" attribute to be global functions, using the "global_symbol" attribute + * - The \p OutlineCompilerFunctionsWithExistingGlobalSymbols pass will lift inline functions with + * a matching "Compiler" attribute to be global functions, using the "global_symbol" attribute * already assigned. Can be used before custom lowering. * * Note that ideally "Compiler" attributed functions would be made global functions as early as @@ -36,15 +36,22 @@ * * See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc. * - * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom - * cache to both allocate "global_symbol" names and ensure two strucurally equal functions are - * assigned the same name, and thus lowered only once. This is used by Collage when preparing - * the optimally partitioned IRModule). + * - (The \p OutlineCompilerFunctions pass is a more general version of the above which can use + * a custom cache to both allocate "global_symbol" names and ensure two structurally equal + * functions are assigned the same name, and thus lowered only once. This is used by Collage + * when preparing the optimally partitioned IRModule). * - * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" - * attribute with the same function with just an "Extern" attribute, signalling the function - * has been dealt with. However calls to such functions will be left unchanged. Can be used - * after lowering to cleanup the IRModule. + * - The \p MarkCompilerFunctionsAsExtern pass will update the attributes of global functions + * with a matching "Compiler" attribute to have just the "Extern" attribute. That will signal + * the function has been dealt with. However calls to such functions will be left unchanged. + * Can be used after lowering to cleanup the IRModule. + * + * - The \p InlineCompilerFunctions pass can selectively inline global functions with a matching + * "Compiler" attribute who's name appears in the given set. Obviously it's more sensible to + * not create that function in the first place, however some external codegen have rules to + * accept or reject partitionings based on the overall partitioned function body. This pass + * can be used do the legwork, and will take care to not only inline the outer "Compiler" + * annotated funcition, but also any "Composite" annotated functions in its body. */ #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ @@ -126,6 +133,16 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co */ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); +/*! + * \brief A pass to inline all global "Compiler" functions which are bound to a global var + * in \p global_vars. Both the global function and any "Composite" functions it its body are + * inlined. + * + * This pass may be useful for external codegen which needs to undo partitioning based on + * properties of the entire partition. + */ +transform::Pass InlineCompilerFunctions(Array global_vars); + } // namespace transforms } // namespace relay } // namespace tvm diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index b9eb11547595..d2476f2361db 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -42,7 +42,7 @@ def make_consts(dtype, shapes): } -def inlined_mod(): +def original_mod(): return tvm.parser.parse( """ #[version = "0.0.5"] @@ -143,10 +143,35 @@ def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i ) +def expected_inlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %0 = nn.dense(%x0, meta[relay.Constant][0], units=2304); + %1 = add(%0, meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + def test_outline_compiler_functions_with_existing_global_symbols(): actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( "cutlass" - )(inlined_mod()) + )(original_mod()) tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) @@ -157,5 +182,12 @@ def test_mark_compiler_functions_as_extern(): tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True) +def test_inline_compiler_functions(): + mod = expected_outlined_mod() + gv = mod.get_global_var("tvmgen_default_cutlass_main_0") + actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctions([gv])(mod) + tvm.ir.assert_structural_equal(actual_inlined_mod, expected_inlined_mod(), map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main() From 72dcb3c5fb9d1b32795272b7fb09f6eb11ffa022 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 27 Jun 2022 12:24:03 -0700 Subject: [PATCH 2/3] - Lints --- python/tvm/relay/op/contrib/tensorrt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 55aef9a5f986..963ffd3a7f89 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -26,7 +26,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item -from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem +from tvm.relay.expr import Call, Constant, TupleGetItem from tvm.relay.expr_functor import ExprMutator, ExprVisitor from tvm.relay.op.contrib.register import register_pattern_table From 862e0e147023f469797adc9e11185584ac561dcb Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 27 Jun 2022 13:32:51 -0700 Subject: [PATCH 3/3] - Only AOT compilation paths ensure "executor" is provided as a Target attribute. --- python/tvm/relay/op/contrib/tensorrt.py | 2 +- python/tvm/relay/transform/transform.py | 8 ++++---- src/relay/transforms/compiler_function_utils.cc | 12 ++++++------ src/relay/transforms/compiler_function_utils.h | 6 +++--- .../relay/transform/test_compiler_function_utils.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 963ffd3a7f89..a69e2d410529 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -1107,7 +1107,7 @@ def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule: and mod[gv].attrs["Compiler"] == "tensorrt" and not is_valid_subgraph(mod[gv].params, mod[gv].body) ] - return relay.transform.InlineCompilerFunctions(global_vars_to_inline)(mod) + return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod) class RemoveDropout(ExprMutator): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 663a9291b2f8..979664f72ca3 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1422,11 +1422,11 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) -def InlineCompilerFunctions(global_vars): +def InlineCompilerFunctionsBoundTo(global_vars): """Inlines all global functions bound to a global var in global_vars. - Both the global "Compiler" attributed function, and any "Composite" functions it its body are - inlined. + Both the global "Compiler" attributed function, and any calls to "Composite" functions it its + body are inlined. This pass may be useful for external codegen which needs to undo partitioning based on properties of the entire partition. @@ -1441,4 +1441,4 @@ def InlineCompilerFunctions(global_vars): ret : tvm.transform.Pass The pass. """ - return _ffi_api.InlineCompilerFunctions(global_vars) + return _ffi_api.InlineCompilerFunctionsBoundTo(global_vars) diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 4fc8a01ed75a..1b0f002f1def 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -230,10 +230,10 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); } -transform::Pass InlineCompilerFunctions(Array global_vars) { +transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars) { runtime::TypedPackedFunc pass_func = [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { - VLOG(1) << "InlineCompilerFunctions with global_vars: " << PrettyPrint(global_vars); + VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars); if (global_vars.empty()) { return mod; } @@ -249,19 +249,19 @@ transform::Pass InlineCompilerFunctions(Array global_vars) { output_mod->Add(kv.first, new_function); } } - VLOG(1) << "InlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); + VLOG(1) << "InlineCompilerFunctionsBoundTo result:" << std::endl << PrettyPrint(output_mod); return output_mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsImpl", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsBoundTo", {}); } TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") .set_body_typed(MarkCompilerFunctionsAsExtern); -TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctions") - .set_body_typed(InlineCompilerFunctions); +TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo") + .set_body_typed(InlineCompilerFunctionsBoundTo); } // namespace transforms } // namespace relay diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 5cd89cbb9d2d..6664594fc0a0 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -135,13 +135,13 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); /*! * \brief A pass to inline all global "Compiler" functions which are bound to a global var - * in \p global_vars. Both the global function and any "Composite" functions it its body are - * inlined. + * in \p global_vars. Both the global function and any calls to "Composite" functions it its body + * are inlined. * * This pass may be useful for external codegen which needs to undo partitioning based on * properties of the entire partition. */ -transform::Pass InlineCompilerFunctions(Array global_vars); +transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars); } // namespace transforms } // namespace relay diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index d2476f2361db..66abeff8ab29 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -185,7 +185,7 @@ def test_mark_compiler_functions_as_extern(): def test_inline_compiler_functions(): mod = expected_outlined_mod() gv = mod.get_global_var("tvmgen_default_cutlass_main_0") - actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctions([gv])(mod) + actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctionsBoundTo([gv])(mod) tvm.ir.assert_structural_equal(actual_inlined_mod, expected_inlined_mod(), map_free_vars=True)