83 changes: 52 additions & 31 deletions src/relay/transforms/partition_graph.cc
@@ -261,37 +261,26 @@ class Partitioner : public MixedModeMutator {
}

/*!
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
* \brief Check whether an expr is a constant or a tuple that contains only constants.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs.
Array<Expr> fields;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
out_expr_to_idx[ret_node] = out_idx++;
}
}
bool IsConstant(const Expr& expr) const {
if (expr->IsInstance<ConstantNode>()) return true;
if (!expr->IsInstance<TupleNode>()) return false;
const auto* tn = expr.as<TupleNode>();
return std::all_of(tn->fields.begin(), tn->fields.end(),
[](const Expr& e) { return e->IsInstance<ConstantNode>(); });
}

/*!
* \brief Create a call to the function that represents a region.
* \note The customized optimization pipeline, if one is registered for the target,
* will also be invoked to optimize each function handled by external codegen.
*/
Call CreateRegionCall(AnnotatedRegion region, const Array<Expr>& fields,
const CallNode* end_node) {
Array<Var> params;
Array<Expr> param_expr;
Map<Var, Expr> params_bind;

auto IsConstant = [](const Expr& expr) {
if (expr->IsInstance<ConstantNode>()) return true;
if (!expr->IsInstance<TupleNode>()) return false;
const auto* tn = expr.as<TupleNode>();
return std::all_of(tn->fields.begin(), tn->fields.end(),
[](const Expr& e) { return e->IsInstance<ConstantNode>(); });
};

for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first);
if (IsConstant(pair.second)) {
@@ -314,18 +303,25 @@ class Partitioner : public MixedModeMutator {
std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

// Constant propagation
if (!params_bind.empty()) {
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}
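// Invoke the customized optimization pipeline registered by the external
// codegen ("relay.ext.<target>.optimize"), if one exists.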
std::string ext_opt = "relay.ext." + target + ".optimize";
auto pf = tvm::runtime::Registry::Get(ext_opt);
if (pf != nullptr) {
auto mod = IRModule::FromExpr(global_region_func);
mod = (*pf)(mod);
global_region_func = Downcast<Function>(mod->Lookup("main"));
}

global_region_func =
WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func =
WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
@@ -340,6 +336,31 @@ class Partitioner : public MixedModeMutator {
auto call = Call(glob_func, param_expr);
region_func_meta_[region].func_call = call;

return call;
}
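For reference, the hook consulted here can be supplied from Python by registering a packed function named "relay.ext.<target>.optimize" that maps an IRModule to an IRModule (the partitioner then looks up "main" in the result). A minimal sketch, assuming a hypothetical target name "my_target"; the new test_extern_opt below exercises the same hook end-to-end:

import tvm
from tvm import relay

def my_optimize(mod):
    # Run whatever Relay passes the external codegen needs before compilation.
    return relay.transform.FoldConstant()(mod)

tvm.register_func("relay.ext.my_target.optimize", my_optimize)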

/*!
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields, a unique list of the region's outputs.
Array<Expr> fields;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
out_expr_to_idx[ret_node] = out_idx++;
}
}

Call call = CreateRegionCall(region, fields, end_node);

// Create output expr(s) for the function call.
if (out_expr_to_idx.size() == 1) {
// Single output directly uses the call node as the output expr.
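To illustrate the multi-output behavior described in the comment above: when a partitioned region produces more than one output, the generated region function returns a Tuple and main consumes the single call through TupleGetItem nodes. A minimal sketch, assuming a hypothetical target name "my_target":

import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(2, 2))
y = relay.var("y", shape=(2, 2))
# Annotate both ops for the same (hypothetical) external target so they land
# in one region with two outputs.
xb = relay.annotation.compiler_begin(x, "my_target")
yb = relay.annotation.compiler_begin(y, "my_target")
add = relay.annotation.compiler_end(xb + yb, "my_target")
sub = relay.annotation.compiler_end(xb - yb, "my_target")
func = relay.Function([x, y], relay.Tuple([add, sub]))
mod = tvm.IRModule.from_expr(func)
mod = transform.PartitionGraph()(mod)
# The region function returns a Tuple; main accesses its fields through
# TupleGetItem(call, 0) and TupleGetItem(call, 1).
print(mod["main"])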
32 changes: 32 additions & 0 deletions tests/python/relay/test_pass_partition_graph.py
@@ -1286,6 +1286,37 @@ def test_tuple_output_exec():
[(10, 10), (10, 10)],
[(a_data + b_data), (a_data - b_data)])

def test_extern_opt():
    def Optimize(mod):
        return relay.transform.FoldConstant()(mod)

    tvm.register_func("relay.ext.test_target.optimize", Optimize)

    x = relay.var('x', shape=(2, 2))
    y0 = relay.var('y0', shape=(2, 2))
    y1 = relay.var('y1', shape=(2, 2))
    yy0 = relay.annotation.compiler_begin(y0, 'test_target')
    yy1 = relay.annotation.compiler_begin(y1, 'test_target')
    z = yy0 + yy1
    end = relay.annotation.compiler_end(z, 'test_target')
    f = relay.Function([x, y0, y1], end * x)
    c = np.ones(shape=(2, 2), dtype="float32")
    f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)})
    mod = tvm.IRModule()
    mod["main"] = f
    mod = transform.PartitionGraph()(mod)

    try:
        t0 = mod["test_target_0"]
    except Exception:
        raise KeyError("test_target_0 not found")

    assert isinstance(t0.body, relay.Constant)
    expected = np.empty([2, 2])
    expected.fill(2)
    tvm.testing.assert_allclose(t0.body.data.asnumpy(), expected, rtol=1e-5,
                                atol=1e-5)

if __name__ == "__main__":
    test_multi_node_compiler()
    test_extern_ccompiler_single_op()
@@ -1305,3 +1336,4 @@ def test_tuple_output_exec():
    test_constant_tuples()
    test_flatten_tuple_output()
    test_tuple_output_exec()
    test_extern_opt()