From e36080446a4ec00c7281e73b1cdeb2a02ff33094 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Apr 2025 11:32:08 +0900 Subject: [PATCH 1/4] fix test_flatten --- .../torch/base_fx_graph_translator.py | 19 +++++++++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 18 ------------------ .../test_frontend_from_exported_program.py | 4 ---- 4 files changed, 20 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 890f925079e0..c334886db1e5 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -19,6 +19,7 @@ # pylint: disable=import-outside-toplevel """Base class for PyTorch FX Graph importer.""" import abc +from functools import reduce import math from typing import Callable, Dict, Optional, Tuple, Union @@ -1018,6 +1019,24 @@ def _expand_as(self, node: fx.Node) -> relax.Var: other_shape = self.shape_of(args[1]) # the shape of 'other' return self.block_builder.emit(relax.op.broadcast_to(data, other_shape)) + def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var: + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _flatten(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0) + end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1) + return self._flatten_impl(x, start_dim, end_dim) + def _flip(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2e7c682aa34b..e1be71afd6e4 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -377,6 +377,7 @@ def create_convert_map( "cumprod.default": self._cumprod, "expand.default": self._expand, "expand_as.default": self._expand_as, + "flatten.using_ints": self._flatten, "flip.default": self._flip, "gather.default": self._gather, "permute.default": self._permute, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 3ddf919c2ed1..e79c1dbc48fa 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -415,24 +415,6 @@ def _chunk(self, node: fx.Node) -> relax.Var: dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var: - shape = self.shape_of(x) - start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim - end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim - flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) - new_shape = ( - 
[shape[i] for i in range(0, start_dim)] - + [flattened] - + [shape[i] for i in range(end_dim + 1, len(shape))] - ) - return self.block_builder.emit(relax.op.reshape(x, new_shape)) - - def _flatten(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0) - end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1) - return self._flatten_impl(x, start_dim, end_dim) - def _flatten_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 2175f9aa391c..e34eb729b3f3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2702,10 +2702,6 @@ def main( verify_model(Expand2(), example_args, {}, expected1) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def test_flatten(): class Flatten(Module): def __init__(self): From eb0455bdda34c6d28d5f0876f9a84d8cc3cffebb Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Apr 2025 11:33:06 +0900 Subject: [PATCH 2/4] re-enable test_split --- tests/python/relax/test_frontend_from_exported_program.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e34eb729b3f3..f853bc5b7151 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2903,10 +2903,6 @@ def main( verify_model(Slice2(), example_args, {}, expected2) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def test_split(): class Chunk(Module): def forward(self, input): From 86ca78cf847115dc64dbe28d3e5f2a3e90ac3496 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Apr 2025 11:40:50 +0900 Subject: [PATCH 3/4] fix test_to_copy --- .../frontend/torch/base_fx_graph_translator.py | 15 +++++++++++++++ .../frontend/torch/exported_program_translator.py | 3 +++ .../relax/test_frontend_from_exported_program.py | 7 ++----- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c334886db1e5..d99411bd5658 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1252,6 +1252,21 @@ def _new_ones(self, node: fx.Node) -> relax.Var: ) ) + ########## DataType ########## + + def _to(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = BaseFXGraphImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = BaseFXGraphImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + ########## Others ########## def _getitem(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py index e1be71afd6e4..26121ecdea10 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -412,6 +412,9 @@ def create_convert_map( "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, + # datatype + "to.dtype": self._to, + "to.dtype_layout": self._to, # other "getitem": self._getitem, } diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f853bc5b7151..8281f2b98676 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3332,10 +3332,6 @@ def main( verify_model(NewOnes(), example_args, {}, expected1) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def test_to_copy(): # float class ToFloat(Module): @@ -3386,7 +3382,8 @@ def main( ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): - gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,) + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) R.output(gv) return gv From 373174d94975679ca8b2e9a364e6ffb7d72c9f57 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 4 Apr 2025 11:41:22 +0900 Subject: [PATCH 4/4] re-enable test_batchnorm2d --- tests/python/relax/test_frontend_from_exported_program.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8281f2b98676..cc2f669d32e0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1021,10 +1021,6 @@ def main( verify_model(Min1(), example_args1, {}, expected_min1) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def test_batchnorm2d(): class BatchNorm2d(Module): def __init__(self):
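
A quick sanity check on patch 1: the reshape-based flatten added in
_flatten_impl can be replayed outside TVM on plain Python ints. This is a
minimal sketch under the assumption of static shapes; flatten_shape is a
hypothetical stand-in for _flatten_impl, which operates on relax shape
expressions rather than ints:

    from functools import reduce

    def flatten_shape(shape, start_dim=0, end_dim=-1):
        # Normalize negative dims, mirroring _flatten_impl.
        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
        # Collapse dims [start_dim, end_dim] into a single product.
        flattened = reduce(lambda a, b: a * b, shape[start_dim : end_dim + 1])
        return list(shape[:start_dim]) + [flattened] + list(shape[end_dim + 1 :])

    # torch.flatten defaults collapse everything into one dimension.
    assert flatten_shape((1, 2, 3, 4)) == [24]
    # start_dim=1 keeps the leading (batch) dimension intact.
    assert flatten_shape((1, 2, 3, 4), start_dim=1) == [1, 24]
    assert flatten_shape((1, 2, 3, 4), start_dim=1, end_dim=2) == [1, 6, 4]

The negative-dim normalization shown here is what lets the importer accept
torch.flatten's default end_dim=-1 without a special case before emitting
the single relax.op.reshape.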