diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index b9e0823e88fa..21660decf2fa 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   }
 
   void VisitExpr_(const TupleNode* op) {
+    CHECK(graph_.node_map.count(op));
+    Node* tuple_node = graph_.node_map.at(op);
+    tuple_node->pattern = kInjective;
     for (const Expr& field : op->fields) {
-      this->Update(field, nullptr, kOpaque);
+      this->Update(field, tuple_node, kInjective);
     }
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
@@ -712,32 +715,15 @@ class FuseMutator : private ExprMutator {
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
       auto* ret_group = gmap_.at(call)->FindRoot();
-      Array<Expr> new_args;
-      for (auto arg : call->args) {
-        auto type = arg->checked_type();
-        CHECK(gmap_.count(arg.get()))
-            << "cannot find group of " << arg;
-        auto* arg_group = gmap_.at(arg.get())->FindRoot();
-        Expr new_arg = this->Mutate(arg);
-
-        if (ret_group != arg_group) {
-          Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
-          new_args.push_back(param);
-        } else {
-          new_args.push_back(new_arg);
-        }
-      }
+      Array<Expr> new_args = GetNewArguments(call->args, ret_group);
+
       auto new_call = CallNode::make(
           call->op, new_args, call->attrs, call->type_args);
 
       if (ret_group->root_ref == call) {
         // This is the root of the group
         // create the new call node.
-        const GroupInfo& ginfo = ginfo_[ret_group];
-        auto func = FunctionNode::make(
-            ginfo.params, new_call, call->checked_type(), {});
-        func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
-        return CallNode::make(func, ginfo.arguments, Attrs());
+        return MakeNewFunction(ret_group, call->checked_type(), new_call);
       } else {
         // This is an intermediate node of a fused function
         // simply return the new call.
@@ -747,6 +733,51 @@ class FuseMutator : private ExprMutator {
       return ExprMutator::VisitExpr_(call);
     }
   }
+
+  Expr VisitExpr_(const TupleNode* tuple) {
+    auto* ret_group = gmap_.at(tuple)->FindRoot();
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
+    Tuple new_tuple = TupleNode::make(new_fields);
+    if (ret_group == gmap_.at(tuple)) {
+      bool isolated = true;
+      for (size_t i = 0; i < new_fields.size(); ++i) {
+        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
+      }
+      if (isolated) {
+        // Do not put an isolated tuple into a function
+        return ExprMutator::VisitExpr_(tuple);
+      }
+      // This tuple has been fused with other ops before it
+      return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
+    }
+    // This tuple is an intermediate node in the group
+    return new_tuple;
+  }
+
+  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
+    const GroupInfo& ginfo = ginfo_[group];
+    auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
+    func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
+    return CallNode::make(func, ginfo.arguments, Attrs());
+  }
+
+  Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
+                              GraphPartitioner::Group* current_group) {
+    Array<Expr> new_args;
+    for (auto arg : args) {
+      auto* arg_group = gmap_.at(arg.get())->FindRoot();
+      auto type = arg->checked_type();
+      Expr new_arg = this->Mutate(arg);
+      if (current_group != arg_group) {
+        Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
+        new_args.push_back(param);
+      } else {
+        new_args.push_back(new_arg);
+      }
+    }
+    return new_args;
+  }
+
   // Debug function, dump the group assignment in text.
   void DebugDumpGroup(const Expr& body) {
     std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
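Note on the refactor above: the argument-rewriting loop that used to live inline in VisitExpr_(const CallNode*) is now shared between the call and tuple visitors. GetNewArguments walks each argument, finds the root of that argument's fusion group via the union-find in GraphPartitioner, and lifts any argument produced by a different group into a parameter of the function being built; MakeNewFunction then wraps the group's body in a function tagged "Primitive". The Python below is a minimal sketch of that group bookkeeping, written for illustration only; the names (Group, find_root, get_or_alloc_param, get_new_arguments) are invented stand-ins for the C++ in fuse_ops.cc.

# Illustrative sketch of the group bookkeeping behind GetNewArguments.
# All names here are invented for illustration; the real logic is C++.

class Group:
    """A fusion group; groups merge union-find style during partitioning."""
    def __init__(self, name):
        self.parent = None     # None means this group is its own root
        self.name = name
        self.params = []       # parameters of the fused function
        self.arguments = []    # outer expressions bound to those parameters

    def find_root(self):
        # Follow parent pointers to the representative, with path compression.
        root = self
        while root.parent is not None:
            root = root.parent
        node = self
        while node.parent is not None:
            node.parent, node = root, node.parent
        return root

    def get_or_alloc_param(self, arg):
        # Reuse an existing parameter if this outer value was already lifted.
        for param, bound in zip(self.params, self.arguments):
            if bound is arg:
                return param
        param = "p%d" % len(self.params)
        self.params.append(param)
        self.arguments.append(arg)
        return param

def get_new_arguments(args, groups, current_group):
    """Mirror of GetNewArguments: values produced outside the current
    group become parameters of the fused function being assembled."""
    new_args = []
    for arg in args:
        if groups[arg].find_root() is not current_group:
            new_args.append(current_group.get_or_alloc_param(arg))
        else:
            new_args.append(arg)
    return new_args

if __name__ == "__main__":
    g_pool = Group("pool")   # e.g. the max_pool2d group
    g_main = Group("main")   # e.g. the upsampling/concat/add group
    groups = {"pooled": g_pool, "x": g_main}
    # "pooled" comes from another group, so it is lifted to parameter p0;
    # "x" is produced inside the current group and passes through.
    print(get_new_arguments(["pooled", "x"], groups, g_main))  # ['p0', 'x']

Reusing an already-allocated parameter when the same outer value appears twice is what keeps a shared input from being passed to the fused function more than once, as exercised by the tests below.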
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 27806791c399..28ea8dd28988 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -28,8 +28,6 @@ def expected():
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
-
-
 def test_conv2d_fuse():
     """Test fusion case of conv2d"""
     def before(dshape):
@@ -106,7 +104,86 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_concatenate():
+    """Test fusion case involving concat op and Tuple node"""
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        concat = relay.concatenate((upsampled, x), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        f0 = relay.Function([x], pooled)
+
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p1 = relay.var("p1", shape=dshape)
+        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        concat = relay.concatenate((upsampled, p1), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        f1 = relay.Function([p0, p1], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y, x])
+        return relay.Function([x], z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_root():
+    """Test fusion case where Tuple node is the root in its group"""
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
+        out = relay.Tuple((upsampled, x))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(dshape):
+        x = relay.var("x", shape=dshape)
+        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        f0 = relay.Function([x], pooled)
+
+        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
+        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
+        upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
+        out = relay.Tuple((upsampled, p1))
+        f1 = relay.Function([p0, p1], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y, x])
+        return relay.Function([x], z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
+    test_concatenate()
+    test_tuple_root()
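For an end-to-end view of the new behavior, the snippet below mirrors test_concatenate outside the test harness, using the same relay.ir_pass API as the tests above, and prints the fused program. Per the expected() function in that test, max_pool2d ends up in its own "Primitive" function, while upsampling, concatenate, and add fuse into a second one now that TupleNode is treated as kInjective rather than kOpaque.

# Standalone repro of test_concatenate; assumes the relay.ir_pass API used
# in the tests above (infer_type, fuse_ops, free_vars).
from tvm import relay

dshape = (1, 16, 64, 64)
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, x), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
func = relay.Function(relay.ir_pass.free_vars(out), out)

func = relay.ir_pass.infer_type(func)
fused = relay.ir_pass.fuse_ops(func, opt_level=2)
fused = relay.ir_pass.infer_type(fused)
# Expect two "Primitive" functions: max_pool2d alone, and
# upsampling + concatenate + add fused together.
print(fused)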