From a193da0421e5a21e66f31864418d5c8269fd1e24 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Fri, 25 Apr 2025 09:55:30 +0000 Subject: [PATCH 1/6] add op support for zeros_like and fill_ --- .../torch/base_fx_graph_translator.py | 12 +++++ .../torch/exported_program_translator.py | 2 + .../tvm/relax/frontend/torch/fx_translator.py | 13 ++---- .../test_frontend_from_exported_program.py | 45 +++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 24 +++++++++- 5 files changed, 84 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3e81ff1f0bfe..0beae572aa05 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1426,6 +1426,14 @@ def _fill(self, node: fx.Node) -> relax.Var: dtype = x.struct_info.dtype value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + + def _fill_inplace(self, node: fx.Node) -> relax.Var: + target_tensor = self.env[node.args[0]] + fill_value = relax.const(node.args[1]) + dtype = target_tensor.struct_info.dtype + filled_tensor = self.block_builder.emit(relax.op.full_like(target_tensor, fill_value, dtype)) + self.env[node.args[0]] = filled_tensor + return filled_tensor def _full(self, node: fx.Node) -> relax.Var: import torch @@ -1639,6 +1647,10 @@ def _zeros_inplace(self, node: fx.Node) -> relax.Var: output = self.block_builder.emit(relax.op.zeros_like(x)) self.env[node.args[0]] = output return output + + def _zeros_like(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.zeros_like(x)) @abc.abstractmethod def create_convert_map( diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a3ab575c4b78..1c384d81ad58 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -471,6 +471,7 @@ def create_convert_map( "eye.default": self._eye, "eye.m": self._eye, "fill.Scalar": self._fill, + "fill_.Scalar": self._fill_inplace, "full.default": self._full, "full_like.default": self._full_like, "index_select.default": self._index_select, @@ -485,6 +486,7 @@ def create_convert_map( ), "zero_.default": self._zeros_inplace, "zeros.default": self._zeros, + "zeros_like.default": self._zeros_like, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 18dba2d988f2..48e072eca99f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -515,15 +515,6 @@ def _size(self, node: fx.Node) -> relax.Expr: ########## Creation ########## - def _inplace_fill(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - self.env[node.args[0]] = filled - return filled - def _inplace_copy(self, node: fx.Node) -> relax.Var: src = self.env[node.args[1]] self.env[node.args[0]] = src @@ -828,7 +819,8 @@ def create_convert_map( "clone": lambda node: self.env[node.args[0]], "empty": self._empty, "empty_like": self._empty_like, - "fill_": self._inplace_fill, + "fill": self._fill, + "fill_": self._fill_inplace, "full": self._full, "index_select": self._index_select, "masked_fill_": self._inplace_masked_fill, @@ -842,6 +834,7 @@ def create_convert_map( ), "tensor": self._tensor, "zero_": self._zeros_inplace, + "zeros_like": self._zeros_like, "copy_": self._inplace_copy, # datatype "astype": self._type, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e3b6f4ad9c17..299bc28682e8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3647,6 +3647,30 @@ def main( verify_model(Fill(), example_args, {}, Expected) +def test_fill_inplace(): + class FillInplace(Module): + def forward(self, input: torch.Tensor): + input.fill_(1.5) # In-place operation + return input + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full_like( + inp_0, R.const(1.5, "float32"), dtype="void" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(FillInplace(), example_args, {}, Expected) + + def test_masked_fill(): class Masked_Fill(Module): def forward(self, input: torch.Tensor, mask: torch.Tensor): @@ -4014,6 +4038,27 @@ def main( verify_model(Zeros(), example_args, {}, Expected) +def test_zeros_like(): + class ZerosLike(Module): + def forward(self, input): + return torch.zeros_like(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + verify_model(ZerosLike(), example_args, {}, Expected) + + def test_type_as(): class TypeAs(Module): def forward(self, input, other): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4003202d4f55..b83a386054a7 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3203,8 +3203,8 @@ class Expected: @R.function def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.full( - R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" + lv: R.Tensor((10, 10), dtype="float32") = R.full_like( + inp_0, R.const(1.5, "float32"), dtype="void" ) gv: R.Tensor((10, 10), dtype="float32") = lv R.output(gv) @@ -4717,6 +4717,26 @@ def main( verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected) +def test_zeros_like(): + class ZerosLike(Module): + def forward(self, data): + return torch.zeros_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void") + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ZerosLike(), [([128, 128], "float32")], {}, Expected) + + def test_type_as(): class TypeAs(Module): def forward(self, data, other): From ec551875bd13653d8dee393e75c246f6980b288b Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Sat, 26 Apr 2025 04:53:31 +0000 Subject: [PATCH 2/6] fixing whitespace issues --- .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 0beae572aa05..04a812a4e8d8 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1426,12 +1426,14 @@ def _fill(self, node: fx.Node) -> relax.Var: dtype = x.struct_info.dtype value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - + def _fill_inplace(self, node: fx.Node) -> relax.Var: target_tensor = self.env[node.args[0]] fill_value = relax.const(node.args[1]) dtype = target_tensor.struct_info.dtype - filled_tensor = self.block_builder.emit(relax.op.full_like(target_tensor, fill_value, dtype)) + filled_tensor = self.block_builder.emit( + relax.op.full_like(target_tensor, fill_value, dtype) + ) self.env[node.args[0]] = filled_tensor return filled_tensor @@ -1647,7 +1649,7 @@ def _zeros_inplace(self, node: fx.Node) -> relax.Var: output = self.block_builder.emit(relax.op.zeros_like(x)) self.env[node.args[0]] = output return output - + def _zeros_like(self, node: fx.node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.zeros_like(x)) From 05034edfcc377f1f193329a35916bb69c5efd04f Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Sat, 26 Apr 2025 14:13:56 +0000 Subject: [PATCH 3/6] unity issue --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 299bc28682e8..5a698b5439dd 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3661,7 +3661,7 @@ def main( ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full_like( - inp_0, R.const(1.5, "float32"), dtype="void" + inp_0, R.const(1.5, "float32"), dtype="float32" ) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) From d5c36be52603d629961b565cac67aa08b5cf62cd Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Sun, 27 Apr 2025 06:59:22 +0000 Subject: [PATCH 4/6] solved datatype issue --- tests/python/relax/test_frontend_from_fx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b83a386054a7..eaafc7a77cd6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3204,7 +3204,7 @@ class Expected: def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full_like( - inp_0, R.const(1.5, "float32"), dtype="void" + inp_0, R.const(1.5, "float32"), dtype="float32" ) gv: R.Tensor((10, 10), dtype="float32") = lv R.output(gv) From 577d0ac89fbdef97d439aa6cb798ad22a27cba39 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 28 Apr 2025 01:46:03 +0000 Subject: [PATCH 5/6] unity issue --- .../frontend/torch/base_fx_graph_translator.py | 17 ++++++++--------- .../torch/exported_program_translator.py | 2 +- .../tvm/relax/frontend/torch/fx_translator.py | 2 +- .../test_frontend_from_exported_program.py | 16 +++++++++------- tests/python/relax/test_frontend_from_fx.py | 4 ++-- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 04a812a4e8d8..a2f50bf9a98d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1427,15 +1427,14 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - def _fill_inplace(self, node: fx.Node) -> relax.Var: - target_tensor = self.env[node.args[0]] - fill_value = relax.const(node.args[1]) - dtype = target_tensor.struct_info.dtype - filled_tensor = self.block_builder.emit( - relax.op.full_like(target_tensor, fill_value, dtype) - ) - self.env[node.args[0]] = filled_tensor - return filled_tensor + def _inplace_fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled def _full(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1c384d81ad58..df37a5b45085 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -471,7 +471,7 @@ def create_convert_map( "eye.default": self._eye, "eye.m": self._eye, "fill.Scalar": self._fill, - "fill_.Scalar": self._fill_inplace, + "fill_.Scalar": self._inplace_fill, "full.default": self._full, "full_like.default": self._full_like, "index_select.default": self._index_select, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 48e072eca99f..492416e97f7c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -820,7 +820,7 @@ def create_convert_map( "empty": self._empty, "empty_like": self._empty_like, "fill": self._fill, - "fill_": self._fill_inplace, + "fill_": self._inplace_fill, "full": self._full, "index_select": self._index_select, "masked_fill_": self._inplace_masked_fill, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 5a698b5439dd..cc03bcaa0cbb 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3650,24 +3650,26 @@ def main( def test_fill_inplace(): class FillInplace(Module): def forward(self, input: torch.Tensor): - input.fill_(1.5) # In-place operation + input.fill_(42.0) return input @tvm.script.ir_module class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.full_like( - inp_0, R.const(1.5, "float32"), dtype="float32" + lv: R.Tensor((2, 3), dtype="float32") = R.full( + R.shape([2, 3]), + R.const(42.0, "float32"), + dtype="float32" ) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) R.output(gv) return gv - example_args = (torch.randn(10, 10, dtype=torch.float32),) + example_args = (torch.randn(2, 3, dtype=torch.float32),) verify_model(FillInplace(), example_args, {}, Expected) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index eaafc7a77cd6..c6f4c40522f2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3203,8 +3203,8 @@ class Expected: @R.function def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.full_like( - inp_0, R.const(1.5, "float32"), dtype="float32" + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" ) gv: R.Tensor((10, 10), dtype="float32") = lv R.output(gv) From a36ba62993af8e6d3d78dd383d001e1e0a669e59 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 28 Apr 2025 04:03:18 +0000 Subject: [PATCH 6/6] lint error --- tests/python/relax/test_frontend_from_exported_program.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cc03bcaa0cbb..d1f0e5767aa3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3661,9 +3661,7 @@ def main( ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): with R.dataflow(): lv: R.Tensor((2, 3), dtype="float32") = R.full( - R.shape([2, 3]), - R.const(42.0, "float32"), - dtype="float32" + R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32" ) gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) R.output(gv)