From 2250b803521a1bf13702c3dac6cdad1e484d10f2 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Tue, 15 Apr 2025 09:32:50 +0000 Subject: [PATCH 1/7] add op support for roll op --- .../torch/base_fx_graph_translator.py | 68 ++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 1 + .../test_frontend_from_exported_program.py | 125 ++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 125 ++++++++++++++++++ 5 files changed, 320 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2652b167e5c0..82162d72821e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -21,7 +21,12 @@ import abc from functools import reduce import math +<<<<<<< HEAD from typing import Callable, Dict, Optional, Tuple, Union, List +======= +from typing import Callable, Dict, Optional, Tuple, Union +import tvm +>>>>>>> 20cb5dd08 (add op support for roll op) from tvm import relax @@ -1163,6 +1168,69 @@ def _repeat(self, node: fx.Node) -> relax.Var: x = args[0] dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) + + def _roll(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + shifts = args[1] + dims = args[2] if len(args) > 2 else None + + # Get original shape + original_shape = self.shape_of(input_tensor) + + def to_int(val): + if isinstance(val, tvm.tir.IntImm): + return int(val.value) + elif isinstance(val, int): + return val + elif hasattr(val, '__int__'): + return int(val) + raise TypeError(f"Unsupported type for shift/dim: {type(val)}") + + def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var: + shape = self.shape_of(tensor) + + dim_size = shape.values[dim] + shift_val = to_int(shift) + dim_size_val = to_int(dim_size) + shift_mod = shift_val % dim_size_val + if shift_mod == 0: + return tensor + + split_pos = dim_size_val - shift_mod + part1 = self.block_builder.emit(relax.op.strided_slice(tensor,axes=[dim],begin=[0],end=[split_pos],strides=[1],)) + part2 = self.block_builder.emit(relax.op.strided_slice(tensor,axes=[dim],begin=[split_pos],end=[dim_size_val],strides=[1],)) + return self.block_builder.emit(relax.op.concat([part2, part1], axis=dim)) + + # Handle dims=None (flatten -> roll -> reshape) + if dims is None: + flattened = self.block_builder.emit(relax.op.reshape(input_tensor, (-1,))) + shift_scalar = to_int(shifts[0] if isinstance(shifts, (list, tuple)) else shifts) + rolled = roll_single_dim(flattened, shift_scalar, 0) + return self.block_builder.emit(relax.op.reshape(rolled, original_shape)) + + # Normalize shifts and dims + if isinstance(shifts, (list, tuple)): + shifts = [to_int(s) for s in shifts] + else: + shifts = [to_int(shifts)] + + if isinstance(dims, (list, tuple)): + dims = [to_int(d) for d in dims] + else: + dims = [to_int(dims)] + + if len(shifts) != len(dims): + raise ValueError("shifts and dims must have the same length") + + result = input_tensor + rank = len(original_shape.values) + for shift, dim in zip(shifts, dims): + if dim < 0: + dim += rank + result = roll_single_dim(result, shift, dim) + + return result def _reshape(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 5d4f3437b257..7c7973e8e49c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -423,6 +423,7 @@ def create_convert_map( "narrow.default": self._narrow, "permute.default": self._permute, "repeat.default": self._repeat, + "roll.default":self._roll, "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e6b1fdd223ea..5a34befb9296 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -750,6 +750,7 @@ def create_convert_map( "numel": self._numel, "permute": self._permute, "repeat": self._repeat, + "roll": self._roll, "reshape": self._reshape, "scatter": self._scatter, "select": self._select, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 9259936dc223..d34aa57904b8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2968,6 +2968,131 @@ def main( verify_model(ReshapeAs(), example_args, {}, expected1) +def test_roll(): + class Roll1(Module): + def forward(self, x): + return torch.roll(x, 1) + + class Roll2(Module): + def forward(self, x): + return torch.roll(x, -1, 0) + + class Roll3(Module): + def forward(self, x): + return torch.roll(x, shifts=(2, 1), dims=(0, 1)) + + # Test case 1: torch.roll(x, 1) + @I.ir_module + class Expected1: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) + lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( + lv, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(7)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( + lv, + axes=[0], + begin=[R.prim_value(7)], + end=[R.prim_value(8)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) + lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) + R.output(gv) + return gv + + # Test case 2: torch.roll(x, -1, 0) + @I.ir_module + class Expected2: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(1)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) + R.output(gv) + return gv + + # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1)) + @I.ir_module + class Expected3: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + # First roll along dim=0 with shift=2 + lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(2)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) + + # Second roll along dim=1 with shift=1 + lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( + lv2, + axes=[1], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( + lv2, + axes=[1], + begin=[R.prim_value(1)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) + R.output(gv) + return gv + + # Test inputs + example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64) + + # Run verification for each case + verify_model(Roll1(), (example_input,), {}, Expected1) + verify_model(Roll2(), (example_input,), {}, Expected2) + verify_model(Roll3(), (example_input,), {}, Expected3) + + def test_select_slice(): class Slice1(Module): def forward(self, x): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 53c925e14ee6..5c92cca9ca7e 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3560,6 +3560,131 @@ def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float3 verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2) +def test_roll(): + class Roll1(Module): + def forward(self, x): + return torch.roll(x, 1) + + class Roll2(Module): + def forward(self, x): + return torch.roll(x, -1, 0) + + class Roll3(Module): + def forward(self, x): + return torch.roll(x, shifts=(2, 1), dims=(0, 1)) + + # Test case 1: torch.roll(x, 1) + @I.ir_module + class Expected1: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) + lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( + lv, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(7)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( + lv, + axes=[0], + begin=[R.prim_value(7)], + end=[R.prim_value(8)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) + lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) + R.output(gv) + return gv + + # Test case 2: torch.roll(x, -1, 0) + @I.ir_module + class Expected2: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(1)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) + R.output(gv) + return gv + + # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1)) + @I.ir_module + class Expected3: + @R.function + def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + with R.dataflow(): + # First roll along dim=0 with shift=2 + lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( + x, + axes=[0], + begin=[R.prim_value(2)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) + + # Second roll along dim=1 with shift=1 + lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( + lv2, + axes=[1], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( + lv2, + axes=[1], + begin=[R.prim_value(1)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False + ) + lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) + R.output(gv) + return gv + + # Test inputs + example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64) + + # Run verification for each case + verify_model(Roll1(), (example_input,), {}, Expected1) + verify_model(Roll2(), (example_input,), {}, Expected2) + verify_model(Roll3(), (example_input,), {}, Expected3) + + def test_view(): input_info = [([1, 2, 3, 4], "float32")] From ab0411d531c1a4bdeea46a7c2d7172d63187b0a1 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Tue, 15 Apr 2025 10:21:15 +0000 Subject: [PATCH 2/7] lint fix --- .../torch/base_fx_graph_translator.py | 23 +++++++++++++++---- .../torch/exported_program_translator.py | 2 +- .../test_frontend_from_exported_program.py | 18 +++++++-------- tests/python/relax/test_frontend_from_fx.py | 18 +++++++-------- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 82162d72821e..d9ddbe76afe8 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1168,7 +1168,7 @@ def _repeat(self, node: fx.Node) -> relax.Var: x = args[0] dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) - + def _roll(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) input_tensor = args[0] @@ -1183,7 +1183,7 @@ def to_int(val): return int(val.value) elif isinstance(val, int): return val - elif hasattr(val, '__int__'): + elif hasattr(val, "__int__"): return int(val) raise TypeError(f"Unsupported type for shift/dim: {type(val)}") @@ -1198,8 +1198,23 @@ def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var: return tensor split_pos = dim_size_val - shift_mod - part1 = self.block_builder.emit(relax.op.strided_slice(tensor,axes=[dim],begin=[0],end=[split_pos],strides=[1],)) - part2 = self.block_builder.emit(relax.op.strided_slice(tensor,axes=[dim],begin=[split_pos],end=[dim_size_val],strides=[1],)) + part1 = self.block_builder.emit( + relax.op.strided_slice( + tensor, + axes=[dim], + begin=[0], + end=[split_pos], + strides=[1], + ) + ) + part2 = self.block_builder.emit( + relax.op.strided_slice( + tensor,axes=[dim], + begin=[split_pos], + end=[dim_size_val], + strides=[1], + ) + ) return self.block_builder.emit(relax.op.concat([part2, part1], axis=dim)) # Handle dims=None (flatten -> roll -> reshape) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7c7973e8e49c..932607287571 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -423,7 +423,7 @@ def create_convert_map( "narrow.default": self._narrow, "permute.default": self._permute, "repeat.default": self._repeat, - "roll.default":self._roll, + "roll.default": self._roll, "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index d34aa57904b8..80c0bd5fb4f5 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2994,7 +2994,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(7)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( lv, @@ -3002,7 +3002,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(7)], end=[R.prim_value(8)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) @@ -3022,7 +3022,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( x, @@ -3030,7 +3030,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(1)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) @@ -3050,7 +3050,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( x, @@ -3058,10 +3058,10 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(2)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - + # Second roll along dim=1 with shift=1 lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( lv2, @@ -3069,7 +3069,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( lv2, @@ -3077,7 +3077,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(1)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 5c92cca9ca7e..f38a7cff0e03 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3586,7 +3586,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(7)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( lv, @@ -3594,7 +3594,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(7)], end=[R.prim_value(8)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) @@ -3614,7 +3614,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( x, @@ -3622,7 +3622,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(1)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) @@ -3642,7 +3642,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( x, @@ -3650,10 +3650,10 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(2)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - + # Second roll along dim=1 with shift=1 lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( lv2, @@ -3661,7 +3661,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( lv2, @@ -3669,7 +3669,7 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" begin=[R.prim_value(1)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) From 2dd9f66169254f84ccfbbc868e5db9da177aa145 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Tue, 15 Apr 2025 17:13:57 +0000 Subject: [PATCH 3/7] fixed unity check --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d9ddbe76afe8..6eb8748bf0be 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1172,8 +1172,8 @@ def _repeat(self, node: fx.Node) -> relax.Var: def _roll(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) input_tensor = args[0] - shifts = args[1] - dims = args[2] if len(args) > 2 else None + shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts", None) + dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None) # Get original shape original_shape = self.shape_of(input_tensor) From 9a57bda46363b082d70f42f38d4e4991199aa012 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 16 Apr 2025 04:03:58 +0000 Subject: [PATCH 4/7] add unit test in fx_graph --- tests/python/relax/test_frontend_from_fx.py | 99 +++++++-------------- 1 file changed, 32 insertions(+), 67 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f38a7cff0e03..b77a55c938bf 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3577,28 +3577,20 @@ def forward(self, x): @I.ir_module class Expected1: @R.function - def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"): with R.dataflow(): - lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) + lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8])) lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(7)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(7)], + strides=[R.prim_value(1)], assume_inbound=False ) lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(7)], - end=[R.prim_value(8)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv, axes=[0], begin=[R.prim_value(7)], end=[R.prim_value(8)], + strides=[R.prim_value(1)], assume_inbound=False ) lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) + gv: R.Tensor((4, 2), dtype="int64") = lv4 R.output(gv) return gv @@ -3606,83 +3598,56 @@ def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype=" @I.ir_module class Expected2: @R.function - def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"): with R.dataflow(): lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) + inp_0, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(1)], + strides=[R.prim_value(1)], assume_inbound=False + ) # Row 0: [[a, b]] lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(1)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + inp_0, axes=[0], begin=[R.prim_value(1)], end=[R.prim_value(4)], + strides=[R.prim_value(1)], assume_inbound=False ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) # [[c, d], [e, f], [g, h], [a, b]] + gv: R.Tensor((4, 2), dtype="int64") = lv2 R.output(gv) return gv - # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1)) + # Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1)) @I.ir_module class Expected3: @R.function - def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): + def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"): with R.dataflow(): - # First roll along dim=0 with shift=2 lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, + inp_0, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(2)], + strides=[R.prim_value(1)], assume_inbound=False ) lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(2)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + inp_0, axes=[0], begin=[R.prim_value(2)], end=[R.prim_value(4)], + strides=[R.prim_value(1)], assume_inbound=False ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - - # Second roll along dim=1 with shift=1 + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) # [[e, f], [g, h], [a, b], [c, d]] lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv2, axes=[1], begin=[R.prim_value(0)], end=[R.prim_value(1)], + strides=[R.prim_value(1)], assume_inbound=False ) lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(1)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv2, axes=[1], begin=[R.prim_value(1)], end=[R.prim_value(2)], + strides=[R.prim_value(1)], assume_inbound=False ) - lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) + lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) # [[f, e], [h, g], [b, a], [d, c]] + gv: R.Tensor((4, 2), dtype="int64") = lv5 R.output(gv) return gv # Test inputs - example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64) + input_info = [([4, 2], "int64")] # Run verification for each case - verify_model(Roll1(), (example_input,), {}, Expected1) - verify_model(Roll2(), (example_input,), {}, Expected2) - verify_model(Roll3(), (example_input,), {}, Expected3) + verify_model(Roll1(), input_info, {}, Expected1) + verify_model(Roll2(), input_info, {}, Expected2) + verify_model(Roll3(), input_info, {}, Expected3) def test_view(): From 391eb34bede7b2c4e8a5c1887231da4f6ba768ad Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 16 Apr 2025 06:57:44 +0000 Subject: [PATCH 5/7] lint issues --- .../torch/base_fx_graph_translator.py | 3 +- tests/python/relax/test_frontend_from_fx.py | 74 +++++++++++++------ 2 files changed, 54 insertions(+), 23 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6eb8748bf0be..e528e0b43b48 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1209,7 +1209,8 @@ def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var: ) part2 = self.block_builder.emit( relax.op.strided_slice( - tensor,axes=[dim], + tensor, + axes=[dim], begin=[split_pos], end=[dim_size_val], strides=[1], diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b77a55c938bf..93240d6708c9 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3581,12 +3581,20 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 with R.dataflow(): lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8])) lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( - lv, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(7)], - strides=[R.prim_value(1)], assume_inbound=False + lv, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(7)], + strides=[R.prim_value(1)], + assume_inbound=False ) lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( - lv, axes=[0], begin=[R.prim_value(7)], end=[R.prim_value(8)], - strides=[R.prim_value(1)], assume_inbound=False + lv, + axes=[0], + begin=[R.prim_value(7)], + end=[R.prim_value(8)], + strides=[R.prim_value(1)], + assume_inbound=False ) lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) @@ -3601,14 +3609,22 @@ class Expected2: def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"): with R.dataflow(): lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( - inp_0, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(1)], - strides=[R.prim_value(1)], assume_inbound=False - ) # Row 0: [[a, b]] + inp_0, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False + ) lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( - inp_0, axes=[0], begin=[R.prim_value(1)], end=[R.prim_value(4)], - strides=[R.prim_value(1)], assume_inbound=False + inp_0, + axes=[0], + begin=[R.prim_value(1)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) # [[c, d], [e, f], [g, h], [a, b]] + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) gv: R.Tensor((4, 2), dtype="int64") = lv2 R.output(gv) return gv @@ -3620,31 +3636,45 @@ class Expected3: def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"): with R.dataflow(): lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - inp_0, axes=[0], begin=[R.prim_value(0)], end=[R.prim_value(2)], - strides=[R.prim_value(1)], assume_inbound=False + inp_0, + axes=[0], + begin=[R.prim_value(0)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False ) lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - inp_0, axes=[0], begin=[R.prim_value(2)], end=[R.prim_value(4)], - strides=[R.prim_value(1)], assume_inbound=False + inp_0, + axes=[0], + begin=[R.prim_value(2)], + end=[R.prim_value(4)], + strides=[R.prim_value(1)], + assume_inbound=False ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) # [[e, f], [g, h], [a, b], [c, d]] + lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, axes=[1], begin=[R.prim_value(0)], end=[R.prim_value(1)], - strides=[R.prim_value(1)], assume_inbound=False + lv2, + axes=[1], + begin=[R.prim_value(0)], + end=[R.prim_value(1)], + strides=[R.prim_value(1)], + assume_inbound=False ) lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, axes=[1], begin=[R.prim_value(1)], end=[R.prim_value(2)], - strides=[R.prim_value(1)], assume_inbound=False + lv2, + axes=[1], + begin=[R.prim_value(1)], + end=[R.prim_value(2)], + strides=[R.prim_value(1)], + assume_inbound=False ) - lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) # [[f, e], [h, g], [b, a], [d, c]] + lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) gv: R.Tensor((4, 2), dtype="int64") = lv5 R.output(gv) return gv - # Test inputs input_info = [([4, 2], "int64")] - # Run verification for each case verify_model(Roll1(), input_info, {}, Expected1) verify_model(Roll2(), input_info, {}, Expected2) verify_model(Roll3(), input_info, {}, Expected3) From 699ae028f25d0b60937a7e9a8e3dd43cf5147b09 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 16 Apr 2025 08:19:36 +0000 Subject: [PATCH 6/7] lint check --- tests/python/relax/test_frontend_from_fx.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 93240d6708c9..c52255638072 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3586,7 +3586,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(0)], end=[R.prim_value(7)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( lv, @@ -3594,7 +3594,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(7)], end=[R.prim_value(8)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) @@ -3614,7 +3614,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( inp_0, @@ -3622,7 +3622,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(1)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) gv: R.Tensor((4, 2), dtype="int64") = lv2 @@ -3641,7 +3641,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(0)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( inp_0, @@ -3649,7 +3649,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(2)], end=[R.prim_value(4)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( @@ -3658,7 +3658,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(0)], end=[R.prim_value(1)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( lv2, @@ -3666,7 +3666,7 @@ def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int6 begin=[R.prim_value(1)], end=[R.prim_value(2)], strides=[R.prim_value(1)], - assume_inbound=False + assume_inbound=False, ) lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) gv: R.Tensor((4, 2), dtype="int64") = lv5 From 41cdd29970cd1fb442e8692f61cba8fb8d53e8a5 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 16 Apr 2025 08:46:14 +0000 Subject: [PATCH 7/7] confilct resolved --- .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e528e0b43b48..ae4c918900ec 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -21,14 +21,9 @@ import abc from functools import reduce import math -<<<<<<< HEAD from typing import Callable, Dict, Optional, Tuple, Union, List -======= -from typing import Callable, Dict, Optional, Tuple, Union -import tvm ->>>>>>> 20cb5dd08 (add op support for roll op) -from tvm import relax +from tvm import relax, tir class BaseFXGraphImporter(metaclass=abc.ABCMeta): @@ -1179,7 +1174,7 @@ def _roll(self, node: fx.Node) -> relax.Var: original_shape = self.shape_of(input_tensor) def to_int(val): - if isinstance(val, tvm.tir.IntImm): + if isinstance(val, tir.IntImm): return int(val.value) elif isinstance(val, int): return val