5 changes: 4 additions & 1 deletion include/tvm/relax/transform.h
@@ -492,12 +492,15 @@ TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = Nul
* corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".
* This must be True if the created composite functions are intended to be offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \param entry_function_names The names of functions that should be considered as entry points. If
* not specified, all externally exposed functions will be considered as entry points.
* \return The Pass.
*
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false);
bool annotate_codegen = false,
const tvm::Array<String>& entry_function_names = {});

/*!
* \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
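The doc comment above relies on a naming convention: with annotate_codegen enabled, the codegen name given to the wrapper function is taken from the composite pattern name, and the "dnnl" / "dnnl.conv2d_relu" example suggests it is the prefix before the first dot. A minimal sketch of that convention; the helper name and the second pattern name are illustrative, not part of this PR:

def codegen_name_from_pattern(pattern_name: str) -> str:
    # By this convention, the text before the first '.' names the external backend.
    return pattern_name.split(".", 1)[0]

assert codegen_name_from_pattern("dnnl.conv2d_relu") == "dnnl"
assert codegen_name_from_pattern("cutlass.matmul_bias_relu") == "cutlass"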
5 changes: 5 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -890,6 +890,7 @@ def FuseOpsByPattern(
patterns: List[Union[FusionPattern, Tuple]],
bind_constants: bool = True,
annotate_codegen: bool = False,
entry_functions: Optional[List[str]] = None,
) -> tvm.ir.transform.Pass:
"""Apply pattern matching to each function in the given module, and group matched expressions
into a new function.
@@ -919,6 +920,9 @@
This must be True if the created composite functions are intended to be offloaded to
an external backend without using the MergeCompositeFunctions pass.

entry_functions : Optional[List[str]]
The set of entry functions to start from.

Returns
-------
ret : tvm.transform.Pass
@@ -938,6 +942,7 @@
converted_patterns,
bind_constants,
annotate_codegen,
entry_functions or [],
) # type: ignore


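As context for the new entry_functions argument, here is a usage sketch against the Python API. It assumes a recent TVM build with the Relax TVMScript frontend; the module, shapes, and pattern name are illustrative. Passing entry_functions=["main"] restricts pattern matching to "main", while any other functions in the module are left as they are.

import tvm
from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.script import relax as R


@tvm.script.ir_module
class ConvReluModule:
    @R.function
    def main(
        data: R.Tensor((1, 16, 32, 32), "float32"),
        weight: R.Tensor((16, 16, 3, 3), "float32"),
    ) -> R.Tensor((1, 16, 32, 32), "float32"):
        with R.dataflow():
            conv = R.nn.conv2d(data, weight, padding=[1, 1, 1, 1])
            out = R.nn.relu(conv)
            R.output(out)
        return out


# conv2d followed by relu, registered under a "dnnl."-prefixed composite name.
conv_pat = is_op("relax.nn.conv2d")(wildcard(), wildcard())
pattern = is_op("relax.nn.relu")(conv_pat)

fuse = relax.transform.FuseOpsByPattern(
    [("dnnl.conv2d_relu", pattern)],
    bind_constants=False,
    annotate_codegen=True,
    entry_functions=["main"],  # only "main" is scanned for matches
)
print(fuse(ConvReluModule))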
69 changes: 46 additions & 23 deletions src/relax/transform/fuse_ops.cc
@@ -690,8 +690,16 @@ class OperatorFusor : public ExprMutator {
* \brief The main transformation on the IRModule
* \return The new IRModule after transformation
*/
IRModule Transform() {
for (const auto& gv : mod_->GetGlobalVars()) {
IRModule Transform(const Array<String>& entry_function_names = {}) {
Array<GlobalVar> entry_functions;
if (entry_function_names.empty()) {
entry_functions = mod_->GetGlobalVars();
} else {
for (const auto& name : entry_function_names) {
entry_functions.push_back(mod_->GetGlobalVar(name));
}
}
for (const auto& gv : entry_functions) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
@@ -1023,8 +1031,8 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {

IRModule MakeGroupedFunctions(
IRModule mod, const std::unordered_map<const Object*, GraphPartitioner::Group*>& partition,
bool lift_constants) {
return OperatorFusor(mod, partition, lift_constants).Transform();
bool lift_constants, const Array<String>& entry_function_names) {
return OperatorFusor(mod, partition, lift_constants).Transform(entry_function_names);
}

/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group,
@@ -1269,26 +1277,39 @@ class CompositeFunctionAnnotator : public ExprMutator {
};

IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
bool bind_constants, bool annotate_codegen) {
bool bind_constants, bool annotate_codegen,
Array<String> entry_function_names) {
support::Arena arena;

for (const auto& pattern : patterns) {
OperatorFusor::GroupMap group_map;
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
Array<Function> entry_functions;
if (entry_function_names.size()) {
for (const auto& name : entry_function_names) {
auto gv = mod->GetGlobalVar(name);
auto func = mod->Lookup(gv);
ICHECK(func->IsInstance<FunctionNode>()) << "Entry function must be a relax function";
entry_functions.push_back(Downcast<Function>(func));
}
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
continue;
} else {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
}
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
entry_functions.push_back(Downcast<Function>(base_func));
}

auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), base_func, &arena,
pattern->attrs_getter.value_or(nullptr));
}
OperatorFusor::GroupMap group_map;
for (const auto& func : entry_functions) {
auto map = PatternBasedPartitioner::Run(
pattern->name, pattern->pattern, pattern->annotation_patterns,
pattern->check.value_or(nullptr), func, &arena, pattern->attrs_getter.value_or(nullptr));
for (const auto& [key, value] : map) {
CHECK(!group_map.count(key))
<< "ValueError: "
@@ -1298,7 +1319,8 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
group_map.insert({key, value});
}
}
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants,
entry_function_names);
}
if (annotate_codegen) {
return CompositeFunctionAnnotator(mod).Run();
@@ -1358,10 +1380,11 @@ Pass FuseOps(int fuse_opt_level) {
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);

Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
bool annotate_codegen) {
bool annotate_codegen, const Array<String>& entry_function_names) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen,
entry_function_names);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
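The fuse_ops.cc changes above determine which functions a pattern is matched against: explicitly passed entry_function_names take precedence (each must name a Relax function), and otherwise every Relax function that is not a PrimFunc and is not already marked Primitive, Composite, or Codegen is treated as an entry point. A rough, self-contained Python model of that selection logic, deliberately not using the TVM API:

def select_entry_functions(functions, entry_function_names):
    """functions: name -> {"kind": "relax" | "prim", "attrs": set of attr keys}."""
    if entry_function_names:
        # Explicit names win; the pass checks that each names a Relax function.
        return list(entry_function_names)
    selected = []
    for name, info in functions.items():
        if info["kind"] != "relax":
            continue  # skip TIR PrimFuncs
        if info["attrs"] & {"Primitive", "Composite", "Codegen"}:
            continue  # already fused, composite, or offloaded
        selected.append(name)
    return selected


# Example: only "main" survives the fallback scan.
mod_model = {
    "main": {"kind": "relax", "attrs": set()},
    "fused_conv2d_relu": {"kind": "relax", "attrs": {"Composite"}},
    "matmul_tir": {"kind": "prim", "attrs": set()},
}
assert select_entry_functions(mod_model, []) == ["main"]
assert select_entry_functions(mod_model, ["main"]) == ["main"]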
3 changes: 2 additions & 1 deletion src/relax/transform/utils.h
@@ -137,12 +137,13 @@ inline std::string GetExtSymbol(const Function& func) {
* \param partition A mapping from a subexpression to the containing group.
* \param lift_constants Whether or not to lift bound constants to parameters of the
* grouped function.
* \param entry_function_names The names of the entry functions.
* \return A new module containing grouped functions.
*/
IRModule MakeGroupedFunctions(
IRModule mod,
const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>& partition,
bool lift_constants = true);
bool lift_constants = true, const Array<String>& entry_function_names = {});

/*!
* \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of