From bca8a840aca3bd04b81e3f38c9b1cde6fb88ec84 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Thu, 16 Nov 2023 20:35:20 +0900
Subject: [PATCH 1/5] add support for `aten::unflatten`

---
 python/tvm/relay/frontend/pytorch.py          | 11 +++++++
 tests/python/frontend/pytorch/test_forward.py | 30 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index bdfd8f78b22e..5cd539f66a3b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1546,6 +1546,16 @@ def flatten(self, inputs, input_types):
         out = _op.squeeze(out, axis=squeeze_axes)
         return out
 
+    def unflatten(self, inputs, input_types):
+        data = inputs[0]
+        dim = int(inputs[1])
+        unflattened_size = tuple(inputs[2])
+        dshape = get_const_tuple(self.infer_shape_with_prelude(data))
+        assert len(dshape) > dim
+        new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
+        out = _op.reshape(data, new_shape)
+        return out
+
     def addmm(self, inputs, input_types):
         input_mat = inputs[0]
         mat1 = inputs[1]
@@ -3945,6 +3955,7 @@ def create_convert_map(self):
             "aten::t": self.transpose,
             "aten::numpy_T": self.numpy_T,
             "aten::flatten": self.flatten,
+            "aten::unflatten": self.unflatten,
             "aten::addmm": self.addmm,
             "aten::size": self.size,
             "aten::view": self.view,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 894bea60ed46..50d7930c0578 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1544,6 +1544,34 @@ def _test_flatten(start_dim, end_dim):
     verify_model(_test_flatten(-3, -2), inp)
 
 
+@tvm.testing.uses_gpu
+def test_unflatten():
+    """test_unflatten"""
+
+    def _test_unflatten(dim, unflattened_size):
+        return lambda inp: torch.unflatten(inp, dim, unflattened_size)
+
+    inp = torch.rand(60)
+
+    # [60] -> [3, 5, 2, 2]
+    verify_model(_test_unflatten(0, (3, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (-1, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
+    verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
+    verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)
+
+    inp = torch.rand(3, 4, 1)
+
+    # [3, 4, 1] -> [3, 2, 2, 1]
+    verify_model(_test_unflatten(1, (2, 2)), inp)
+    verify_model(_test_unflatten(1, (-1, 2)), inp)
+
+    inp = torch.rand(5, 12, 3)
+
+    # [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
+    verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)
+
+
 @tvm.testing.uses_gpu
 def test_forward_transpose():
     """test_forward_transpose"""
@@ -4744,7 +4772,7 @@ def test_fn(x, mask):
     verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])
 
 
-@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention', 'aten::unflatten'")
+@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
 def test_transformer():
     """test_transformer"""
     model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)

From 252a58bea7aa698e0271375d5d0ecd475b37fee8 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Thu, 16 Nov 2023 20:44:58 +0900
Subject: [PATCH 2/5] Add check that dshape[dim] % multiplication of dimensions
 in unflattened_size == 0

---
 python/tvm/relay/frontend/pytorch.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5cd539f66a3b..1b36adde6aa9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1552,6 +1552,13 @@ def unflatten(self, inputs, input_types):
         unflattened_size = tuple(inputs[2])
         dshape = get_const_tuple(self.infer_shape_with_prelude(data))
         assert len(dshape) > dim
+
+        mult = 1
+        for s in unflattened_size:
+            if s is not -1:
+                mult *= s
+        assert dshape[dim] % mult == 0
+
         new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
         out = _op.reshape(data, new_shape)
         return out

From 4a8656dfb8b6ffcacc7b1a17853ccdc357c866f6 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 17 Nov 2023 12:42:08 +0900
Subject: [PATCH 3/5] Update shape check

---
 python/tvm/relay/frontend/pytorch.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1b36adde6aa9..ce62d2b833aa 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1553,11 +1553,13 @@ def unflatten(self, inputs, input_types):
         dshape = get_const_tuple(self.infer_shape_with_prelude(data))
         assert len(dshape) > dim
 
-        mult = 1
-        for s in unflattened_size:
-            if s is not -1:
-                mult *= s
-        assert dshape[dim] % mult == 0
+        assert unflattened_size.count(-1) <= 1
+
+        mult = np.multiply.reduce(unflattened_size)
+        if mult < 0:
+            assert dshape[dim] % mult == 0
+        else:
+            assert dshape[dim] == mult
 
         new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
         out = _op.reshape(data, new_shape)

From d5798f2a29257aefbbdb81994a4ee86b079d16b4 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 17 Nov 2023 14:53:23 +0900
Subject: [PATCH 4/5] handle `dim=-1`

---
 python/tvm/relay/frontend/pytorch.py          | 4 +++-
 tests/python/frontend/pytorch/test_forward.py | 6 ++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ce62d2b833aa..473aae19d1b2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1551,7 +1551,9 @@ def unflatten(self, inputs, input_types):
         dim = int(inputs[1])
         unflattened_size = tuple(inputs[2])
         dshape = get_const_tuple(self.infer_shape_with_prelude(data))
-        assert len(dshape) > dim
+
+        dim = dim if dim >= 0 else len(dshape) + dim
+        assert len(dshape) > dim and dim >= 0
 
         assert unflattened_size.count(-1) <= 1
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 50d7930c0578..2f346feced88 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1559,6 +1559,11 @@ def _test_unflatten(dim, unflattened_size):
     verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
     verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
     verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (-1, 5, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, -1, 2, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, -1, 2)), inp)
+    verify_model(_test_unflatten(-1, (3, 5, 2, -1)), inp)
 
     inp = torch.rand(3, 4, 1)
 
@@ -1569,6 +1574,7 @@ def _test_unflatten(dim, unflattened_size):
     inp = torch.rand(5, 12, 3)
 
     # [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
+    verify_model(_test_unflatten(1, (2, 2, 3, 1, 1)), inp)
     verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)
 
 

From 4d26be731c1c0ecc7f45f3eab139678b812bb1b8 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 20 Nov 2023 14:43:45 +0900
Subject: [PATCH 5/5] formatting

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 473aae19d1b2..faed052a034b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1553,7 +1553,7 @@ def unflatten(self, inputs, input_types):
         dshape = get_const_tuple(self.infer_shape_with_prelude(data))
 
         dim = dim if dim >= 0 else len(dshape) + dim
-        assert len(dshape) > dim and dim >= 0
+        assert len(dshape) > dim >= 0
 
         assert unflattened_size.count(-1) <= 1
 
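
Note (illustration only, not part of the patch series): the converter built up across these patches reduces `aten::unflatten` to a pure shape computation followed by a single relay reshape, where any remaining -1 entry is resolved by the reshape itself, NumPy-style. Below is a minimal standalone sketch of that shape computation in plain Python/NumPy; the helper name `unflatten_shape` is invented here for illustration and does not exist in TVM.

import numpy as np

def unflatten_shape(dshape, dim, unflattened_size):
    # Splice `unflattened_size` into `dshape` at position `dim`, applying the
    # same validity checks as the converter in the patches above.
    dim = dim if dim >= 0 else len(dshape) + dim
    assert len(dshape) > dim >= 0
    assert unflattened_size.count(-1) <= 1

    mult = np.multiply.reduce(unflattened_size)
    if mult < 0:
        # One entry is -1: the known sizes must divide the original extent.
        assert dshape[dim] % mult == 0
    else:
        # Fully specified: the product must equal the original extent.
        assert dshape[dim] == mult

    return dshape[:dim] + unflattened_size + dshape[dim + 1 :]

# Example: shape [5, 12, 3] with dim=-2 and sizes (2, 2, 3) -> (5, 2, 2, 3, 3)
print(unflatten_shape((5, 12, 3), -2, (2, 2, 3)))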