77 changes: 24 additions & 53 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -26,7 +26,7 @@
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item
from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem
from tvm.relay.expr import Call, Constant, TupleGetItem
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
from tvm.relay.op.contrib.register import register_pattern_table

@@ -864,7 +864,11 @@ def pattern_table() -> List[
binary_op_pattern_with_const("nn.dense"),
make_predicate(dense_checker),
),
("tensorrt.bias_add", binary_op_pattern("nn.bias_add"), make_predicate(bias_add_checker)),
(
"tensorrt.nn.bias_add",
binary_op_pattern("nn.bias_add"),
make_predicate(bias_add_checker),
),
(
"tensorrt.nn.batch_matmul",
binary_op_pattern("nn.batch_matmul"),
@@ -1062,7 +1066,6 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo
for var in params:
# In implicit batch mode, all inputs must have same batch size
# TODO: (codeislife99) : Fix different dynamic batch size inputs

if isinstance(var.checked_type, relay.TupleType):
for tupe_type in var.checked_type.fields:
# Scalar inputs not allowed
@@ -1079,64 +1082,32 @@ def is_valid_subgraph(params: List[relay.expr.Var], body: relay.expr.Expr) -> bo
return False
if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
input_batch_sizes.append(int(var.checked_type.shape[0]))

if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
logger.info("tensorrt: inputs have different batch sizes")
logger.info("tensorrt: inputs have different batch sizes: %s", input_batch_sizes)
return False

if get_tensorrt_remove_no_mac_subgraphs():
return IsComputeIntensiveGraph().is_graph_compute_intensive(body)
if not IsComputeIntensiveGraph().is_graph_compute_intensive(body):
logger.info("tensorrt: not a compute-intensize sub-graph")
return False

return True
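
Aside: in implicit batch mode TensorRT drops the leading dimension from every input, so all inputs to a partition must agree on one static batch size, which is exactly what the loop above enforces. A minimal sketch of that rule, with a hypothetical helper that is not part of this PR:

def batch_sizes_consistent(shapes):
    # True iff all non-scalar inputs share a single static leading (batch) dimension.
    batch_sizes = {shape[0] for shape in shapes if len(shape) > 0}
    return len(batch_sizes) <= 1

assert batch_sizes_consistent([(8, 3, 224, 224), (8, 1000)])
assert not batch_sizes_consistent([(8, 3, 224, 224), (4, 1000)])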


def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
"""
Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs
is set).
"""

class SubgraphRemover(ExprMutator):
"""
Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
"""

def __init__(
self, subgraphs_to_remove: List[str], mod: tvm.IRModule, new_mod: tvm.IRModule
) -> None:
ExprMutator.__init__(self)
self.subgraphs_to_remove = subgraphs_to_remove
self.mod = mod
self.new_mod = new_mod

def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr:
if isinstance(call.op, GlobalVar):
name = call.op.name_hint
if name in self.subgraphs_to_remove:
# "Inline" the subgraph back into new main function.
func = self.mod[name]
var_map = {}
for arg, param in zip(call.args, func.params):
var_map[param] = super().visit(arg)
new_body = relay.bind(func.body, var_map)
return new_body
if name != "main":
args = []
for arg in call.args:
args.append(super().visit(arg))
return call.op(*args)
return super().visit_call(call)

subgraphs_to_remove: List[str] = []
# Remove invalid subgraphs
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
if not mod[name].attrs or mod[name].attrs["Compiler"] != "tensorrt":
continue
if not is_valid_subgraph(mod[name].params, mod[name].body):
subgraphs_to_remove.append(name)
# Create new pruned module
new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod
Un-partition those partitions which:
- have no multiply-accumulates (if remove_no_mac_subgraphs is True)
- can't actually be supported by TensorRT now that we see the whole partition."""
global_vars_to_inline = [
gv
for gv in mod.get_global_vars()
if mod[gv].attrs
and mod[gv].attrs["Compiler"] == "tensorrt"
and not is_valid_subgraph(mod[gv].params, mod[gv].body)
]
return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod)
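
Aside: the rewrite above replaces the hand-rolled SubgraphRemover mutator with a single application of the generic InlineCompilerFunctionsBoundTo pass, so the parameter-binding and nested "Composite" inlining logic lives in one place instead of being duplicated per backend. A hedged usage sketch, assuming `mod` is an IRModule already partitioned for TensorRT:

from tvm.relay.op.contrib.tensorrt import prune_tensorrt_subgraphs

# Reverts invalid or non-compute-intensive "tensorrt" partitions back to ordinary TVM.
mod = prune_tensorrt_subgraphs(mod)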


class RemoveDropout(ExprMutator):
22 changes: 22 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -1420,3 +1420,25 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
The pass.
"""
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)


def InlineCompilerFunctionsBoundTo(global_vars):
"""Inlines all global functions bound to a global var in global_vars.

Both the global "Compiler" attributed function and any calls to "Composite" functions in its
body are inlined.

This pass may be useful for external codegen which needs to undo partitioning based on
properties of the entire partition.

Parameters
----------
global_vars : Array[tvm.relay.GlobalVar]
The global vars of all 'Compiler' functions to inline.

Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.InlineCompilerFunctionsBoundTo(global_vars)
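
Aside: a hedged usage sketch for the new pass. The partition names below are hypothetical, and `mod` is assumed to be an already-partitioned IRModule:

from tvm import relay

# Inline two specific "Compiler" partitions back into the main function.
gvs = [
    mod.get_global_var("tvmgen_default_tensorrt_main_0"),  # hypothetical names
    mod.get_global_var("tvmgen_default_tensorrt_main_1"),
]
mod = relay.transform.InlineCompilerFunctionsBoundTo(gvs)(mod)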
175 changes: 140 additions & 35 deletions src/relay/transforms/compiler_function_utils.cc
@@ -27,12 +27,28 @@
#include "../op/call/call.h"
#include "tvm/relay/analysis.h"
#include "tvm/relay/expr_functor.h"
#include "tvm/relay/transform.h"

namespace tvm {
namespace relay {
namespace transforms {
namespace {

/*!
* \brief Returns the \p FunctionNode of \p expr if it is a "Compiler" function which should
* be processed by a pass using \p compiler_filter. Otherwise returns null.
*/
const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) {
if (const auto* function_node = expr.as<FunctionNode>()) {
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined() &&
(compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
return function_node;
}
}
return nullptr;
}

/*!
* \brief Rewrite calls to inlined "Compiler" functions to global functions. The given
* module will be extended with the newly outlined functions.
@@ -44,35 +60,31 @@ class Outliner : public MixedModeMutator {

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
Call new_call = Downcast<Call>(post);
if (const auto* function_node = new_call->op.as<FunctionNode>()) {
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined() &&
(compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
auto function = GetRef<Function>(function_node);
DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
<< "' attribute should not have free variables";
// Ask the cache to supply a unique global var for this function.
GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
// Depending on the cache's implementation, two structurally equal (but not object equal)
// functions may be assigned the same global symbol. If so we'll lift it just once, but
// rewrite all the calls.
if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
function =
WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
mod_->Add(global_symbol, function);
}
// Update the call.
return WithFields(new_call, global_symbol);
if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) {
auto function = GetRef<Function>(function_node);
DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
<< "' attribute should not have free variables";
// Ask the cache to supply a unique global var for this function.
GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
// Depending on the cache's implementation, two structurally equal (but not object
// equal) functions may be assigned the same global symbol. If so we'll lift it just
// once, but rewrite all the calls.
if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
function =
WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
mod_->Add(global_symbol, function);
}
// Update the call.
return WithFields(new_call, global_symbol);
}
return post;
}

private:
/*!
* \brief A cached mapping from functions to global variables. Depending on the implementation
* the cache may generate fresh symbols or require the function to already have a "global_symbol"
* attribute, and may share symbols between structurally equal functions.
* the cache may generate fresh symbols or require the function to already have a
* "global_symbol" attribute, and may share symbols between structurally equal functions.
*/
GlobalSymbolCache* cache_;
/*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
@@ -81,6 +93,72 @@ class Outliner : public MixedModeMutator {
IRModule mod_;
};

/*!
* \brief Inline immediate calls to "Composite" functions.
*/
class InnerInliner : public MixedModeMutator {
public:
InnerInliner() = default;

private:
using MixedModeMutator::Rewrite_;

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
Call new_call = Downcast<Call>(post);
if (const auto* function_node = new_call->op.as<FunctionNode>()) {
ICHECK(function_node->GetAttr<String>(attr::kComposite).defined());
ICHECK_EQ(function_node->params.size(), new_call->args.size());
Map<Var, Expr> subst;
for (size_t i = 0; i < new_call->args.size(); ++i) {
subst.Set(function_node->params[i], new_call->args[i]);
}
return Bind(function_node->body, subst);
}
return post;
}
};

/*!
* \brief Inline calls to global "Compiler" functions with global var in \p global_vars.
* Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body
* are inlined.
*/
class OuterInliner : public MixedModeMutator {
public:
OuterInliner(IRModule mod, Array<GlobalVar> global_vars_)
: mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {}

private:
using MixedModeMutator::Rewrite_;

Expr Rewrite_(const CallNode* pre, const Expr& post) final {
Call new_call = Downcast<Call>(post);
if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
auto global_var = GetRef<GlobalVar>(global_var_node);
if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) {
BaseFunc base_func = mod_->Lookup(global_var);
const auto* function_node = base_func.as<FunctionNode>();
ICHECK(function_node);
ICHECK(function_node->GetAttr<String>(attr::kCompiler).defined());
ICHECK_EQ(function_node->params.size(), new_call->args.size());
Map<Var, Expr> subst;
for (size_t i = 0; i < new_call->args.size(); ++i) {
subst.Set(function_node->params[i], new_call->args[i]);
}
Expr new_body = InnerInliner().VisitExpr(function_node->body);
return Bind(new_body, subst);
}
}
return post;
}

private:
/*! \brief Original module we are processing. */
IRModule mod_;
/*! \brief Global vars of functions to inline. */
Array<GlobalVar> global_vars_;
};

} // namespace

GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -106,17 +184,18 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
IRModule mod, transform::PassContext ctx) {
IRModule output_mod = GetRef<IRModule>(mod.CopyOnWrite());
VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
IRModule output_mod = mod->ShallowCopy();
for (const auto& kv : mod->functions) {
const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second);
if (function_node) {
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
Expr new_body =
Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body);
Function new_function =
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
output_mod->Add(kv.first, new_function);
}
}
VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod);
return output_mod;
};

@@ -132,31 +211,57 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
IRModule output_mod = mod->ShallowCopy();
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined() &&
(compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
auto new_function = WithFields(
GetRef<Function>(function_node), function_node->params, function_node->body,
function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
output_mod->Update(kv.first, new_function);
}
if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) {
auto new_function =
WithFields(GetRef<Function>(function_node), function_node->params,
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
output_mod->Update(kv.first, new_function);
}
}
VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod);
return output_mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
}

transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars);
if (global_vars.empty()) {
return mod;
}
VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
IRModule output_mod = mod->ShallowCopy();
for (const auto& kv : mod->functions) {
if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) {
output_mod->Remove(kv.first);
} else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body);
Function new_function =
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
output_mod->Add(kv.first, new_function);
}
}
VLOG(1) << "InlineCompilerFunctionsBoundTo result:" << std::endl << PrettyPrint(output_mod);
return output_mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsBoundTo", {});
}

TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols")
.set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
.set_body_typed(MarkCompilerFunctionsAsExtern);
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
.set_body_typed(InlineCompilerFunctionsBoundTo);

} // namespace transforms
} // namespace relay
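
Aside: for intuition, a Python analogue of the per-call-site substitution InnerInliner and OuterInliner perform: bind the callee's parameters to the call's arguments, then splice in the body. Illustrative only; the authoritative implementation is the C++ above.

from tvm import relay

def inline_one_call(mod, call):
    # Sketch: assumes call.op is a GlobalVar bound to a "Compiler" function in mod.
    func = mod[call.op.name_hint]
    binds = dict(zip(func.params, call.args))
    return relay.bind(func.body, binds)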