From eba773e87592197cc89078c23c4cfb2813545c94 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Mon, 21 Apr 2025 13:02:15 +0530 Subject: [PATCH 1/4] Add support for ones_like,zero_,zeros,type_as,item --- .../torch/base_fx_graph_translator.py | 16 +++ .../torch/exported_program_translator.py | 22 ++++ .../tvm/relax/frontend/torch/fx_translator.py | 6 + .../test_frontend_from_exported_program.py | 124 ++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 88 +++++++++++++ 5 files changed, 256 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ae4c918900ec..f43a6821a9b6 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1514,6 +1514,12 @@ def _to(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(x, dtype)) return x + def _type_as(self, node: fx.Node) -> relax.Var: + input = self.env[node.args[0]] + other = self.env[node.args[1]] + dtype = other.struct_info.dtype + return self.block_builder.emit(relax.op.astype(input, dtype)) + ########## Others ########## def _getitem(self, node: fx.Node) -> relax.Var: @@ -1597,6 +1603,16 @@ def _getitem(self, node: fx.Node) -> relax.Var: else: assert False + def _item(self, node: fx.Node) -> relax.Var: + input = self.env[node.args[0]] + return self.block_builder.emit(relax.op.take(input, relax.const(0, "int64"), axis=0)) + + def _zeros_inplace(self, node: fx.Node) -> relax.Var: + input = self.env[node.args[0]] + output = self.block_builder.emit(relax.op.zeros_like(input)) + self.env[node.args[0]] = output + return output + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..c85e9d208ae3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -253,6 +253,21 @@ def _one_hot(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + def _zeros(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + return self.block_builder.emit( + relax.op.zeros( + size, + dtype + ) + ) + ########## Others ########## def create_convert_map( @@ -462,11 +477,18 @@ def create_convert_map( "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, "ones.default": self._ones, + "ones_like.default": lambda node: self.block_builder.emit( + relax.op.ones_like(self.env[node.args[0]]) + ), + "zero_.default": self._zeros_inplace, + "zeros.default": self._zeros, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, + "type_as.default": self._type_as, # other "getitem": self._getitem, + "item.default": self._item, } def create_input_vars( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5a34befb9296..60b83b5d5c2a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -781,7 +781,11 @@ def create_convert_map( "new_ones": self._new_ones, "ones": self._ones, "one_hot": self._one_hot, + "ones_like": lambda node: self.block_builder.emit( + relax.op.ones_like(self.env[node.args[0]]) + ), "tensor": self._tensor, + "zero_":self._zeros_inplace, # datatype "astype": self._type, "float": self._float, @@ -789,10 +793,12 @@ def create_convert_map( "is_floating_point": self._is_floating_point, "to": self._to, "type": self._type, + "type_as": self._type_as, # other "getattr": self._getattr, "getitem": self._getitem, "sym_size.int": self._sym_size_int, + "item": self._item, } def update_convert_map(self, custom_convert_map: dict): diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80c0bd5fb4f5..13a4dd0ef8ac 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3946,6 +3946,106 @@ def main( verify_model(OneHot(), example_args, {}, Expected) +def test_ones_like(): + class OnesLike(Module): + def forward(self, input): + return torch.ones_like(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like( + input, dtype="void" + ) + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(OnesLike(), example_args, {}, Expected) + + +def test_zero_inplace(): + class ZeroInplace(Module): + def forward(self, input): + return input.zero_() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like( + input, dtype="void" + ) + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(ZeroInplace(), example_args, {}, Expected) + + +def test_zeros(): + class Zeros(Module): + def forward(self, input): + return torch.zeros(5, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.zeros( + R.shape([5, 2]), dtype="float32" + ) + gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(Zeros(), example_args, {}, Expected) + + +def test_type_as(): + class TypeAs(Module): + def forward(self, input, other): + return input.type_as(other) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32"), + other: R.Tensor((128, 128), dtype="float16"), + ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float16") = R.astype( + input, dtype="float16" + ) + gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.rand(128, 128, dtype=torch.float32), + torch.rand(128, 128, dtype=torch.float16), + ) + + verify_model(TypeAs(), example_args, {}, Expected) + + def test_select(): class Select(Module): def forward(self, input): @@ -4377,5 +4477,29 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_item(): + class Item(Module): + def forward(self,x): + return x.item() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((1,), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.take( + input, + R.const(0, "int64"), + axis=0 + ) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, dtype=torch.float32),) + verify_model(Item(), example_args, {}, Expected) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c52255638072..139dda1a2127 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4504,6 +4504,94 @@ def main( verify_model(EmptyLike(), [([5], "float32")], {}, Expected) +def test_ones_like(): + class OnesLike(Module): + def forward(self, data): + return torch.ones_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like( + inp_0, dtype="void" + ) + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(OnesLike(), [([128, 128], "float32")], {}, Expected) + + +def test_zero_inplace(): + class ZeroInplace(Module): + def forward(self, data): + return data.zero_() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like( + inp_0, dtype="void" + ) + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected) + + +def test_type_as(): + class TypeAs(Module): + def forward(self, data, other): + return data.type_as(other) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float16"), + inp_1: R.Tensor((128, 128), dtype="float32"), + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.astype( + inp_0, dtype="float32" + ) + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TypeAs(), [([128, 128], "float16"), ([128, 128], "float32")], {}, Expected) + + +def test_item(): + class Item(Module): + def forward(self, data): + return data.item() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((1,), dtype="float32") + ) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.take( + inp_0, R.const(0, "int64"), axis=0 + ) + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Item(),[([1,],"float32",)],{},Expected,) + def test_numel(): class Numel(Module): def forward(self, data): From 419b5b539ae54436bd833b5d14fdeb69425d122d Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Mon, 21 Apr 2025 14:06:13 +0530 Subject: [PATCH 2/4] Fix lint issues --- .../torch/exported_program_translator.py | 7 +---- .../tvm/relax/frontend/torch/fx_translator.py | 2 +- .../test_frontend_from_exported_program.py | 29 +++++-------------- tests/python/relax/test_frontend_from_fx.py | 22 ++++---------- 4 files changed, 16 insertions(+), 44 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c85e9d208ae3..6306d261b24a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -261,12 +261,7 @@ def _zeros(self, node: fx.Node) -> relax.Var: dtype = self._convert_data_type( node.kwargs.get("dtype", torch.get_default_dtype()), self.env ) - return self.block_builder.emit( - relax.op.zeros( - size, - dtype - ) - ) + return self.block_builder.emit(relax.op.zeros(size, dtype)) ########## Others ########## diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 60b83b5d5c2a..8a79d5601abd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -785,7 +785,7 @@ def create_convert_map( relax.op.ones_like(self.env[node.args[0]]) ), "tensor": self._tensor, - "zero_":self._zeros_inplace, + "zero_": self._zeros_inplace, # datatype "astype": self._type, "float": self._float, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 13a4dd0ef8ac..353b2999d5b8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3958,9 +3958,7 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.ones_like( - input, dtype="void" - ) + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void") gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) R.output(gv) return gv @@ -3982,9 +3980,7 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like( - input, dtype="void" - ) + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4006,9 +4002,7 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5, 2), dtype="float32") = R.zeros( - R.shape([5, 2]), dtype="float32" - ) + lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32") gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4031,9 +4025,7 @@ def main( other: R.Tensor((128, 128), dtype="float16"), ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float16") = R.astype( - input, dtype="float16" - ) + lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16") gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,) R.output(gv) return gv @@ -4479,21 +4471,15 @@ def main( def test_item(): class Item(Module): - def forward(self,x): + def forward(self, x): return x.item() @tvm.script.ir_module class Expected: @R.function - def main( - input: R.Tensor((1,), dtype="float32") - ) -> R.Tuple(R.Tensor((), dtype="float32")): + def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): - lv: R.Tensor((), dtype="float32") = R.take( - input, - R.const(0, "int64"), - axis=0 - ) + lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0) gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4501,5 +4487,6 @@ def main( example_args = (torch.randn(1, dtype=torch.float32),) verify_model(Item(), example_args, {}, Expected) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 139dda1a2127..cd05d2987799 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4516,9 +4516,7 @@ def main( inp_0: R.Tensor((128, 128), dtype="float32") ) -> R.Tensor((128, 128), dtype="float32"): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.ones_like( - inp_0, dtype="void" - ) + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(inp_0, dtype="void") gv: R.Tensor((128, 128), dtype="float32") = lv R.output(gv) return gv @@ -4538,9 +4536,7 @@ def main( inp_0: R.Tensor((128, 128), dtype="float32") ) -> R.Tensor((128, 128), dtype="float32"): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like( - inp_0, dtype="void" - ) + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void") gv: R.Tensor((128, 128), dtype="float32") = lv R.output(gv) return gv @@ -4561,9 +4557,7 @@ def main( inp_1: R.Tensor((128, 128), dtype="float32"), ) -> R.Tensor((128, 128), dtype="float32"): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.astype( - inp_0, dtype="float32" - ) + lv: R.Tensor((128, 128), dtype="float32") = R.astype(inp_0, dtype="float32") gv: R.Tensor((128, 128), dtype="float32") = lv R.output(gv) return gv @@ -4579,18 +4573,14 @@ def forward(self, data): @tvm.script.ir_module class Expected: @R.function - def main( - inp_0: R.Tensor((1,), dtype="float32") - ) -> R.Tensor((), dtype="float32"): + def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), dtype="float32"): with R.dataflow(): - lv: R.Tensor((), dtype="float32") = R.take( - inp_0, R.const(0, "int64"), axis=0 - ) + lv: R.Tensor((), dtype="float32") = R.take(inp_0, R.const(0, "int64"), axis=0) gv: R.Tensor((), dtype="float32") = lv R.output(gv) return gv - verify_model(Item(),[([1,],"float32",)],{},Expected,) + verify_model(Item(),[([1],"float32",)],{},Expected) def test_numel(): class Numel(Module): From 980bd44b47712ca789412b8678568f619ec5ce4c Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Mon, 21 Apr 2025 14:22:37 +0530 Subject: [PATCH 3/4] Fix lint issues --- .../frontend/torch/base_fx_graph_translator.py | 13 +++++++------ .../frontend/torch/exported_program_translator.py | 2 -- tests/python/relax/test_frontend_from_fx.py | 13 ++++++++++++- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f43a6821a9b6..fd4c44964a57 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -21,6 +21,7 @@ import abc from functools import reduce import math +from re import X from typing import Callable, Dict, Optional, Tuple, Union, List from tvm import relax, tir @@ -1515,10 +1516,10 @@ def _to(self, node: fx.Node) -> relax.Var: return x def _type_as(self, node: fx.Node) -> relax.Var: - input = self.env[node.args[0]] + x = self.env[node.args[0]] other = self.env[node.args[1]] dtype = other.struct_info.dtype - return self.block_builder.emit(relax.op.astype(input, dtype)) + return self.block_builder.emit(relax.op.astype(x, dtype)) ########## Others ########## @@ -1604,12 +1605,12 @@ def _getitem(self, node: fx.Node) -> relax.Var: assert False def _item(self, node: fx.Node) -> relax.Var: - input = self.env[node.args[0]] - return self.block_builder.emit(relax.op.take(input, relax.const(0, "int64"), axis=0)) + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0)) def _zeros_inplace(self, node: fx.Node) -> relax.Var: - input = self.env[node.args[0]] - output = self.block_builder.emit(relax.op.zeros_like(input)) + x = self.env[node.args[0]] + output = self.block_builder.emit(relax.op.zeros_like(x)) self.env[node.args[0]] = output return output diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6306d261b24a..94b03e32fed1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -254,8 +254,6 @@ def _one_hot(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) def _zeros(self, node: fx.Node) -> relax.Var: - import torch - args = self.retrieve_args(node) size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) dtype = self._convert_data_type( diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cd05d2987799..4d4319f632c6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4580,7 +4580,18 @@ def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), dtype="float32" R.output(gv) return gv - verify_model(Item(),[([1],"float32",)],{},Expected) + verify_model( + Item(), + [ + ( + [1], + "float32", + ) + ], + {}, + Expected, + ) + def test_numel(): class Numel(Module): From 14701557d208940d02aca60c22186b3855f8fffd Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Mon, 21 Apr 2025 14:31:47 +0530 Subject: [PATCH 4/4] Removed unused import --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index fd4c44964a57..ff3be883fe7c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -21,7 +21,6 @@ import abc from functools import reduce import math -from re import X from typing import Callable, Dict, Optional, Tuple, Union, List from tvm import relax, tir