diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 5585eb8ae6c7..fad4fb781b5a 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -740,8 +740,9 @@ class FuseMutator : private ExprMutator { Array new_fields = GetNewArguments(tuple->fields, ret_group); Tuple new_tuple = TupleNode::make(new_fields); if (ret_group == gmap_.at(tuple)) { - bool isolated = true; - for (size_t i = 0; i < new_fields.size(); ++i) { + // This tuple is the root of its group. Check if all fields come from other groups. + bool isolated = new_fields.size() == ginfo_[ret_group].params.size(); + for (size_t i = 0; i < new_fields.size() && isolated; ++i) { isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i])); } if (isolated) { diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 28ea8dd28988..37019a9575e6 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -182,8 +182,46 @@ def expected(dshape): assert relay.ir_pass.alpha_equal(zz, after) +def test_tuple_strided_slice(): + """ + Test fusion case where the number of fields of tuple and + the number of parameters to the function containing the tuple are different + """ + + def before(dshape): + x = relay.var("x", shape=dshape) + slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) + slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) + out = relay.Tuple((slice1, slice2)) + return relay.Function([x], out) + + def expected(dshape): + x = relay.var("x", shape=dshape) + slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1]) + slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1]) + out = relay.Tuple((slice1, slice2)) + f0 = relay.Function([x], out) + + x = relay.var("x", shape=dshape) + y = relay.Call(f0, [x]) + return relay.Function([x], y) + + dshape = (64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=0) + assert not relay.ir_pass.free_vars(zz) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + assert not relay.ir_pass.free_vars(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + print(zz.astext()) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() test_concatenate() test_tuple_root() + test_tuple_strided_slice()