diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index bdfd8f78b22e..faed052a034b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1546,6 +1546,27 @@ def flatten(self, inputs, input_types):
             out = _op.squeeze(out, axis=squeeze_axes)
         return out
 
+    def unflatten(self, inputs, input_types):
+        data = inputs[0]
+        dim = int(inputs[1])
+        unflattened_size = tuple(inputs[2])
+        dshape = get_const_tuple(self.infer_shape_with_prelude(data))
+
+        dim = dim if dim >= 0 else len(dshape) + dim
+        assert len(dshape) > dim >= 0
+
+        assert unflattened_size.count(-1) <= 1
+
+        mult = np.multiply.reduce(unflattened_size)
+        if mult < 0:
+            assert dshape[dim] % mult == 0
+        else:
+            assert dshape[dim] == mult
+
+        new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
+        out = _op.reshape(data, new_shape)
+        return out
+
     def addmm(self, inputs, input_types):
         input_mat = inputs[0]
         mat1 = inputs[1]
@@ -3945,6 +3966,7 @@ def create_convert_map(self):
             "aten::t": self.transpose,
             "aten::numpy_T": self.numpy_T,
             "aten::flatten": self.flatten,
+            "aten::unflatten": self.unflatten,
             "aten::addmm": self.addmm,
             "aten::size": self.size,
             "aten::view": self.view,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 894bea60ed46..2f346feced88 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1544,6 +1544,40 @@ def _test_flatten(start_dim, end_dim):
     verify_model(_test_flatten(-3, -2), inp)
 
 
+@tvm.testing.uses_gpu
+def test_unflatten():
+    """test_unflatten"""
+
+    def _test_unflatten(dim, unflattened_size):
+        return lambda inp: torch.unflatten(inp, dim, unflattened_size)
+
+    inp = torch.rand(60)
+
+    # [60] -> [3, 5, 2, 2]
+    verify_model(_test_unflatten(0, (3, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (-1, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
+    verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (-1, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, -1, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, -1, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, 2, -1)), inp)
+
+    inp = torch.rand(3, 4, 1)
+
+    # [3, 4, 1] -> [3, 2, 2, 1]
+    verify_model(_test_unflatten(1, (2, 2)), inp)
+    verify_model(_test_unflatten(1, (-1, 2)), inp)
+
+    inp = torch.rand(5, 12, 3)
+
+    # [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
+    verify_model(_test_unflatten(1, (2, 2, 3, 1, 1)), inp)
+    verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)
+
+
 @tvm.testing.uses_gpu
 def test_forward_transpose():
     """test_forward_transpose"""
@@ -4744,7 +4778,7 @@ def test_fn(x, mask):
     verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])
 
 
-@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention', 'aten::unflatten'")
+@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
 def test_transformer():
     """test_transformer"""
     model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
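
Note (not part of the patch): the converter added above lowers aten::unflatten to a single Relay reshape whose target shape is the input shape with dimension dim replaced by the entries of unflattened_size, at most one of which may be -1 and is then inferred by the reshape. A minimal sketch of that shape arithmetic outside TVM, using a hypothetical helper named unflatten_shape:

    import torch

    def unflatten_shape(dshape, dim, unflattened_size):
        # Normalize a negative dim, mirroring the converter.
        dim = dim if dim >= 0 else len(dshape) + dim
        assert len(dshape) > dim >= 0
        # At most one -1 is allowed; reshape infers that extent.
        assert unflattened_size.count(-1) <= 1
        # Splice the unflattened extents in place of the original dimension.
        return dshape[:dim] + unflattened_size + dshape[dim + 1 :]

    x = torch.rand(5, 12, 3)
    new_shape = unflatten_shape(tuple(x.shape), 1, (2, 2, 3))  # (5, 2, 2, 3, 3)
    # A plain reshape to the spliced shape reproduces torch.unflatten.
    assert torch.equal(x.reshape(new_shape), torch.unflatten(x, 1, (2, 2, 3)))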