From 0a2740fd96514bf981734dd9eea2a7f1e6c9ce3b Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 17 Apr 2025 08:56:07 +0000 Subject: [PATCH 1/3] add rsub op support into exported and fx graph frontend --- .../torch/base_fx_graph_translator.py | 10 +++++ .../torch/exported_program_translator.py | 2 + .../tvm/relax/frontend/torch/fx_translator.py | 1 + .../test_frontend_from_exported_program.py | 33 +++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 37 +++++++++++++++++++ 5 files changed, 83 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4c9480b58748..50eb2bcb9c37 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -401,6 +401,16 @@ def call_binary_op(op, lhs, rhs): return convert + def _rsub(self, node: fx.Node)-> relax.Var: + args = self.retrieve_args(node) + input = args[0] + other = args[1] + + if isinstance(other, (int, float)): + other = relax.const(other) + + return self.block_builder.emit(relax.op.subtract(other, input)) + ########## Linear Algebra ########## def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c82a5e2b1100..ff343c498f4a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -304,6 +304,8 @@ def create_convert_map( "relu_.default": self._unary_op(relax.op.nn.relu), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), + "rsub.Tensor": self._rsub, + "rsub.Scalar": self._rsub, "selu.default": self._unary_op(relax.op.nn.selu), "sigmoid.default": self._unary_op(relax.op.sigmoid), "sign.default": self._unary_op(relax.op.sign), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 297529e8bf29..886a23eb1c0f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -699,6 +699,7 @@ def create_convert_map( "pow": self._binary_op(relax.op.power, operator.pow), "or_": self._binary_op(relax.op.bitwise_or, operator.or_), "rshift": self._binary_op(relax.op.right_shift, operator.rshift), + "rsub": self._rsub, "sub": self._binary_op(relax.op.subtract, operator.sub), "truediv": self._binary_op(relax.op.divide, operator.truediv), "xor": self._binary_op(relax.op.bitwise_xor, operator.xor), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 26d3d3f7bde2..08ea051c06e7 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -899,6 +899,7 @@ def test_binary3(): torch.randn(10, 10, dtype=torch.float32), torch.randn(10, 10, dtype=torch.float32), ) + example_args2 = (torch.randn(10, 10, dtype=torch.float32),) # Max class Max1(Module): @@ -940,6 +941,38 @@ def main( verify_model(Min1(), example_args1, {}, expected_min1) + # RSub + class RSub1(Module): + def forward(self, x, y): + return torch.rsub(x, y) + + class RSub2(Module): + def forward(self, x): + return torch.rsub(x, 5.0) + + @tvm.script.ir_module + class expected_rsub1: + @R.function + def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_rsub2: + @R.function + def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(RSub1(), example_args1, {}, expected_rsub1) + verify_model(RSub2(), example_args2, {}, expected_rsub2) + def test_batchnorm2d(): class BatchNorm2d(Module): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index a962de8a3237..bec963c57e20 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1702,6 +1702,43 @@ def main( verify_model(Binary2(op), input_info2, {}, expected_binary2) +# RSub +def test_rsub(): + input_info1 = [([10, 10], "float32"), ([10, 10], "float32")] + input_info2 = [([10, 10], "float32")] + + class RSub1(Module): + def forward(self, x, y): + return torch.rsub(x, y) + + class RSub2(Module): + def forward(self, x): + return torch.rsub(x, 5.0) + + @tvm.script.ir_module + class expected_rsub1: + @R.function + def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_rsub2: + @R.function + def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(RSub1(), input_info1, {}, expected_rsub1) + verify_model(RSub2(), input_info2, {}, expected_rsub2) + + def test_size(): input_info = [([1, 3, 10, 10], "float32")] From 745c77c9def4416331359628ce2aaf921e04e4a9 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 17 Apr 2025 08:57:18 +0000 Subject: [PATCH 2/3] fix trailing whitespace issue --- .../relax/test_frontend_from_exported_program.py | 8 ++++---- tests/python/relax/test_frontend_from_fx.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 08ea051c06e7..15b2afcf01ba 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -945,11 +945,11 @@ def main( class RSub1(Module): def forward(self, x, y): return torch.rsub(x, y) - + class RSub2(Module): def forward(self, x): return torch.rsub(x, 5.0) - + @tvm.script.ir_module class expected_rsub1: @R.function @@ -959,9 +959,9 @@ def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="fl gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - + @tvm.script.ir_module - class expected_rsub2: + class expected_rsub2: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index bec963c57e20..507ea52ec612 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1706,15 +1706,15 @@ def main( def test_rsub(): input_info1 = [([10, 10], "float32"), ([10, 10], "float32")] input_info2 = [([10, 10], "float32")] - + class RSub1(Module): def forward(self, x, y): return torch.rsub(x, y) - + class RSub2(Module): def forward(self, x): return torch.rsub(x, 5.0) - + @tvm.script.ir_module class expected_rsub1: @R.function @@ -1726,7 +1726,7 @@ def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="fl return gv @tvm.script.ir_module - class expected_rsub2: + class expected_rsub2: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): @@ -1734,7 +1734,7 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="fl gv: R.Tensor((10, 10), dtype="float32") = lv R.output(gv) return gv - + verify_model(RSub1(), input_info1, {}, expected_rsub1) verify_model(RSub2(), input_info2, {}, expected_rsub2) From 2b7d9dae201c8fb03cd817ca106d42c183e4f05a Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 17 Apr 2025 09:32:18 +0000 Subject: [PATCH 3/3] fix lint issues in test scripts --- .../relax/frontend/torch/base_fx_graph_translator.py | 12 ++++++------ .../relax/test_frontend_from_exported_program.py | 8 ++++++-- tests/python/relax/test_frontend_from_fx.py | 4 +++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 50eb2bcb9c37..1fa21607a59a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -401,15 +401,15 @@ def call_binary_op(op, lhs, rhs): return convert - def _rsub(self, node: fx.Node)-> relax.Var: + def _rsub(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - input = args[0] - other = args[1] + lhs = args[0] + rhs = args[1] - if isinstance(other, (int, float)): - other = relax.const(other) + if isinstance(rhs, (int, float)): + rhs = relax.const(rhs) - return self.block_builder.emit(relax.op.subtract(other, input)) + return self.block_builder.emit(relax.op.subtract(rhs, lhs)) ########## Linear Algebra ########## diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 15b2afcf01ba..0cb00d216fdc 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -953,7 +953,9 @@ def forward(self, x): @tvm.script.ir_module class expected_rsub1: @R.function - def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + def main( + x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) @@ -963,7 +965,9 @@ def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="fl @tvm.script.ir_module class expected_rsub2: @R.function - def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + def main( + x: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 507ea52ec612..4e847be317d4 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1718,7 +1718,9 @@ def forward(self, x): @tvm.script.ir_module class expected_rsub1: @R.function - def main(x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + def main( + x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x) gv: R.Tensor((10, 10), dtype="float32") = lv