src/relay/pass/fuse_ops.cc (52 additions, 21 deletions)

@@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   }
 
   void VisitExpr_(const TupleNode* op) {
+    CHECK(graph_.node_map.count(op));
+    Node* tuple_node = graph_.node_map.at(op);
+    tuple_node->pattern = kInjective;
     for (const Expr& field : op->fields) {
-      this->Update(field, nullptr, kOpaque);
+      this->Update(field, tuple_node, kInjective);
     }
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
@@ -712,32 +715,15 @@ class FuseMutator : private ExprMutator {
     // then we must have a group assignment for it already.
     CHECK(gmap_.count(call));
     auto* ret_group = gmap_.at(call)->FindRoot();
-    Array<Expr> new_args;
-    for (auto arg : call->args) {
-      auto type = arg->checked_type();
-      CHECK(gmap_.count(arg.get()))
-          << "cannot find group of " << arg;
-      auto* arg_group = gmap_.at(arg.get())->FindRoot();
-      Expr new_arg = this->Mutate(arg);
-
-      if (ret_group != arg_group) {
-        Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
-        new_args.push_back(param);
-      } else {
-        new_args.push_back(new_arg);
-      }
-    }
+    Array<Expr> new_args = GetNewArguments(call->args, ret_group);
 
     auto new_call = CallNode::make(
         call->op, new_args, call->attrs, call->type_args);
 
     if (ret_group->root_ref == call) {
-      // This is the root of the group
-      // create the new call node.
-      const GroupInfo& ginfo = ginfo_[ret_group];
-      auto func = FunctionNode::make(
-          ginfo.params, new_call, call->checked_type(), {});
-      func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
-      return CallNode::make(func, ginfo.arguments, Attrs());
+      return MakeNewFunction(ret_group, call->checked_type(), new_call);
     } else {
       // This is an intermediate node of a fused function
       // simply return the new call.
@@ -747,6 +733,51 @@ class FuseMutator : private ExprMutator {
       return ExprMutator::VisitExpr_(call);
     }
   }
+
+  Expr VisitExpr_(const TupleNode* tuple) {
+    auto* ret_group = gmap_.at(tuple)->FindRoot();
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
+    Tuple new_tuple = TupleNode::make(new_fields);
+    if (ret_group == gmap_.at(tuple)) {
+      bool isolated = true;
+      for (size_t i = 0; i < new_fields.size(); ++i) {
+        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
+      }
+      if (isolated) {
+        // Do not put an isolated tuple into a function
+        return ExprMutator::VisitExpr_(tuple);
+      }
+      // This tuple has been fused with other ops before it
+      return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
+    }
+    // This tuple is an intermediate node in the group
+    return new_tuple;
+  }
+
+  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
+    const GroupInfo& ginfo = ginfo_[group];
+    auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
+    func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
+    return CallNode::make(func, ginfo.arguments, Attrs());
+  }
+
+  Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
+                              GraphPartitioner::Group* current_group) {
+    Array<Expr> new_args;
+    for (auto arg : args) {
+      auto* arg_group = gmap_.at(arg.get())->FindRoot();
+      auto type = arg->checked_type();
+      Expr new_arg = this->Mutate(arg);
+      if (current_group != arg_group) {
+        Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
+        new_args.push_back(param);
+      } else {
+        new_args.push_back(new_arg);
+      }
+    }
+    return new_args;
+  }
+
   // Debug function, dump the group assignment in text.
   void DebugDumpGroup(const Expr& body) {
     std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
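A note for readers following the group bookkeeping in the hunks above: gmap_ maps each expression node to its fusion group, and FindRoot() resolves a group to its current representative after merges, which is why every lookup goes through it before comparing groups. Below is a minimal Python sketch of that union-find idea; the names Group, find_root, and merge_into are illustrative only, not TVM's actual GraphPartitioner API.

class Group:
    """Illustrative stand-in for a fusion group (not TVM's C++ type)."""

    def __init__(self):
        self.parent = None  # None means this group is currently a root

    def find_root(self):
        # Follow parent links to the representative, compressing the
        # path so repeated lookups stay cheap.
        if self.parent is None:
            return self
        self.parent = self.parent.find_root()
        return self.parent

    def merge_into(self, other):
        # Fusing two groups: attach this root under the other root.
        root, other_root = self.find_root(), other.find_root()
        if root is not other_root:  # guard against self-loops
            root.parent = other_root


a, b = Group(), Group()
a.merge_into(b)
assert a.find_root() is b.find_root()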
tests/python/relay/test_pass_fuse_ops.py (79 additions, 2 deletions)

@@ -28,8 +28,6 @@ def expected():
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
-
-
 def test_conv2d_fuse():
     """Test fusion case of conv2d"""
     def before(dshape):
@@ -106,7 +104,86 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_concatenate():
+    """Test fusion case involving concat op and Tuple node"""
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        concat = relay.concatenate((upsampled, x), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        f0 = relay.Function([x], pooled)
+
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p1 = relay.var("p1", shape=dshape)
+        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        concat = relay.concatenate((upsampled, p1), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        f1 = relay.Function([p0, p1], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y, x])
+        return relay.Function([x], z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_root():
+    """Test fusion case where Tuple node is the root in its group"""
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        out = relay.Tuple((upsampled, x))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        f0 = relay.Function([x], pooled)
+
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
+        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        out = relay.Tuple((upsampled, p1))
+        f1 = relay.Function([p0, p1], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y, x])
+        return relay.Function([x], z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
+    test_concatenate()
+    test_tuple_root()
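When one of these alpha_equal checks fails, it helps to look at the fused output directly. The following is a small debugging sketch in the same style as the tests, using only the relay.ir_pass calls already exercised above; show_fused is a hypothetical helper, not an existing API.

from tvm import relay  # same import the test module relies on

def show_fused(func, opt_level=2):
    # Infer types first, as the tests do, then fuse and print the result;
    # fused groups show up as nested functions marked Primitive.
    func = relay.ir_pass.infer_type(func)
    fused = relay.ir_pass.fuse_ops(func, opt_level=opt_level)
    fused = relay.ir_pass.infer_type(fused)
    print(fused)
    return fused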