diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 464bc1cc0b64..ca7f6e5d3908 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -49,6 +49,9 @@ enum OpPatternKind {
   // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
   kOutEWiseFusable = 4,
+  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
+  // but treated specially
+  kTuple = 7,
   // Opaque operation, cannot fuse anything.
   kOpaque = 8
 };
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 6312f023df0d..6ba207934d1b 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -112,6 +112,8 @@ class OpPattern(object):
     COMM_REDUCE = 3
     # Complex op, can still fuse ewise into it
     OUT_ELEMWISE_FUSABLE = 4
+    # Represents tuple node
+    TUPLE = 7
     # Not fusable opaque op
     OPAQUE = 8
 
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 12e3174dcade..55d609872929 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* tuple_node = graph_.node_map.at(op);
-    tuple_node->pattern = kInjective;
+    tuple_node->pattern = kTuple;
     for (const Expr& field : op->fields) {
       if (field->checked_type().as<TensorTypeNode>()) {
         this->Update(field, tuple_node, kInjective);
@@ -661,12 +661,36 @@ class GraphPartitioner {
       // no actions needed if the current node have no dominator
       if (dom_node->parent == nullptr) continue;
       CHECK(!graph_node->extern_ref);
-      // Skip if current node is already fused to the parent.
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
+
+      if (phase == 2) {
+        // Fuse injective ops into intermediate tuples, if any
+        if (group_node->pattern > kInjective) continue;
+        Group* dom_parent_group = groups_[dom_parent_gindex];
+        Group* dom_root_group = dom_parent_group->FindRoot();
+        // If dom node group has a tuple as its root, we do not fuse tuple fields into it
+        if (dom_root_group->pattern == kTuple) continue;
+        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
+          // Now we know the tuple has been fused into subsequent injective ops
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind <= kInjective;
+          };
+          // dom_root_group can also be tuple, as in inception layers
+          // CheckPath is needed to avoid fusing two intermediate tuples
+          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            CommitFuse(graph_node, dom_node->parent->gnode);
+          }
+        }
+        continue;
+      }
+
+      // Skip if current node is already fused to the parent.
       if (groups_[dom_parent_gindex] != nullptr &&
           group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
         continue;
       }
+      // Do not fuse into tuple for now
+      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
       // Try to fuse current node to its post-dominator.
       if (group_node->pattern == kOutEWiseFusable) {
         if (phase != 0) continue;
@@ -702,7 +726,7 @@ class GraphPartitioner {
             CommitFuse(graph_node, dom_node->parent->gnode);
           }
         }
-      } else if (group_node->pattern == kInjective) {
+      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
         // defer injective fusion to second phase.
         // so conv2d always finishes fusing.
         if (phase != 1) continue;
@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
   // get post dominator tree
   auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
   // run fusion algorithm.
-  for (int phase = 0; phase < 2; ++phase) {
+  for (int phase = 0; phase < 3; ++phase) {
     this->RunFuse(graph, post_dom_tree, phase);
   }
   return std::move(groups_);
@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {
 
   Expr VisitExpr_(const TupleNode* tuple) {
     auto* ret_group = gmap_.at(tuple)->FindRoot();
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     if (ret_group == gmap_.at(tuple)) {
-      // This tuple is the root of its group. Check if all fields come from other groups.
-      bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
-      for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
-        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
-      }
-      if (isolated) {
-        // Do not put a isolated tuple into a function
-        return ExprMutator::VisitExpr_(tuple);
-      }
-      // This tuple has been fused with other ops before it
-      for (size_t i = 0; i < new_fields.size(); i++) {
-        // Copy function arguments to tuple field of the output because currently graph memory
-        // planer doesn't support inplace operations
-        if (new_fields[i].as<VarNode>()) {
-          auto copy = Copy(new_fields[i]);
-          new_fields.Set(i, copy);
-        }
-      }
-      return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
+      return ExprMutator::VisitExpr_(tuple);
     }
     // This tuple is an intermediate node in the group
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     return TupleNode::make(new_fields);
   }
diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py
index 3b479b847619..ca4619c97886 100644
--- a/tests/python/relay/test_backend_compile_engine.py
+++ b/tests/python/relay/test_backend_compile_engine.py
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
     relay.build(func, 'llvm')
 
 
+def test_compile_tuple_dup():
+    x = relay.var("data", shape=(16, 16))
+    log = relay.log(x)
+    output = relay.Tuple([log, log])
+    f = relay.Function([x], output)
+    relay.build(f, 'llvm')
+
+
 if __name__ == "__main__":
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
-
+    test_compile_tuple_dup()
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index baafbeebd560..bdffdf7c129f 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -176,16 +176,14 @@ def expected(dshape):
         f0 = relay.Function([x], pooled)
 
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
-        p1_copy = relay.copy(p1)
         upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
-        out = relay.Tuple((upsampled, p1_copy))
-        f1 = relay.Function([p0, p1], out)
+        f1 = relay.Function([p0], upsampled)
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
-        z = relay.Call(f1, [y, x])
-        return relay.Function([x], z)
+        z = relay.Call(f1, [y])
+        tup = relay.Tuple((z, x))
+        return relay.Function([x], tup)
 
     dshape = (1, 16, 64, 64)
     z = before(dshape)
@@ -199,41 +197,6 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
-def test_tuple_strided_slice():
-    """
-    Test fusion case where the number of fields of tuple and
-    the number of parameters to the function containing the tuple are different
-    """
-
-    def before(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        return relay.Function([x], out)
-
-    def expected(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        f0 = relay.Function([x], out)
-
-        x = relay.var("x", shape=dshape)
-        y = relay.Call(f0, [x])
-        return relay.Function([x], y)
-
-    dshape = (64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
-    after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
-
-
 def test_stop_fusion():
     def before(dshape):
@@ -377,13 +340,178 @@ def expected(dim):
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_tuple_intermediate():
+    def before(x):
+        inj = relay.squeeze(x)
+        y1 = relay.add(inj, relay.const(1, "float32"))
+        tmp = relay.squeeze(inj)
+        tmp = relay.add(tmp, relay.const(1, "float32"))
+        y2 = relay.add(tmp, relay.const(1, "float32"))
+        y3 = relay.add(inj, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out_inj = relay.squeeze(concat)
+        out = relay.add(out_inj, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(p0):
+        f0 = before(p0)
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        return relay.Function([x], y)
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(x))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_consecutive():
+    def gen_intermediate_tuple(x):
+        y1 = relay.add(x, relay.const(1, "float32"))
+        y2 = relay.add(x, relay.const(1, "float32"))
+        y3 = relay.add(x, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return out
+
+    def gen_consecutive_tuple(x):
+        y1 = gen_intermediate_tuple(x)
+        y2 = gen_intermediate_tuple(x)
+        y3 = gen_intermediate_tuple(x)
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        return concat
+
+    def before(x):
+        concat = gen_consecutive_tuple(x)
+        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        out2 = relay.add(out, relay.const(1, "float32"))
+        out_tup = relay.Tuple((out, out2))
+        return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        concat = gen_consecutive_tuple(p0)
+        f0 = relay.Function([p0], concat)
+
+        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
"float32")) + f1 = relay.Function([p01], out) + + p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) + out = relay.add(p02, relay.const(1, "float32")) + f2 = relay.Function([p02], out) + + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + z = relay.Call(f1, [y]) + z2 = relay.Call(f2, [z]) + + return relay.Function([x], relay.Tuple((z, z2))) + + dshape = (1, 16, 64, 64) + x = relay.var("x", shape=dshape) + z = before(x) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + relay.build(zz, 'llvm') + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + + +def test_inception_like(): + def conv(data): + y = relay.nn.conv2d(data, relay.var("w"), + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + return relay.nn.relu(data=y) + + def inception_like(data): + c0 = conv(data) + c1 = conv(data) + return relay.concatenate((c0, c1), axis=1) + + def before(dshape): + x = relay.var("x", shape=dshape) + in1 = inception_like(x) + in2 = inception_like(in1) + return relay.Function(relay.ir_pass.free_vars(in2), in2) + + def expected(dshape): + p0 = relay.var("p0", shape=dshape) + c = conv(p0) + f0 = relay.Function(relay.ir_pass.free_vars(c), c) + + p01 = relay.var("p01", shape=dshape) + c = conv(p01) + f1 = relay.Function(relay.ir_pass.free_vars(c), c) + + p02 = relay.var("p02", shape=dshape) + p12 = relay.var("p12", shape=dshape) + concat1 = relay.concatenate((p02, p12), axis=1) + f_concat1 = relay.Function([p02, p12], concat1) + + dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) + + p03 = relay.var("p03", shape=dshape2) + c = conv(p03) + f2 = relay.Function(relay.ir_pass.free_vars(c), c) + + p04 = relay.var("p04", shape=dshape2) + c = conv(p04) + f3 = relay.Function(relay.ir_pass.free_vars(c), c) + + p05 = relay.var("p05", shape=dshape) + p15 = relay.var("p15", shape=dshape) + concat2 = relay.concatenate((p05, p15), axis=1) + f_concat2 = relay.Function([p05, p15], concat2) + + x = relay.var("x", shape=dshape) + c1 = relay.Call(f0, [x, relay.var("w1")]) + c2 = relay.Call(f1, [x, relay.var("w2")]) + concat = relay.Call(f_concat1, [c1, c2]) + c3 = relay.Call(f2, [concat, relay.var("w3")]) + c4 = relay.Call(f3, [concat, relay.var("w4")]) + out = relay.Call(f_concat2, [c3, c4]) + + return relay.Function(relay.ir_pass.free_vars(out), out) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + relay.build(zz, 'llvm') + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() test_concatenate() test_tuple_root() - test_tuple_strided_slice() test_stop_fusion() test_fuse_myia_regression() test_fuse_tuple_get_elemwise() test_tuple_get_root() + test_tuple_intermediate() + test_tuple_consecutive() + test_inception_like()