From 01b274575c53a09b03e55b04fea331485a985f1a Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 29 Mar 2025 17:24:16 +0800 Subject: [PATCH 1/4] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 022a7bffea80..c4008a939688 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,6 +62,10 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## + def _reciprocal(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + def _leakyrelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -708,6 +712,7 @@ def create_convert_map( "logical_not": self._unary_op(relax.op.logical_not), "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), + "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), "round": self._round, "rsqrt": self._unary_op(relax.op.rsqrt), @@ -784,11 +789,13 @@ def create_convert_map( # search "argmax": self._argmax_argmin(relax.op.argmax), "argmin": self._argmax_argmin(relax.op.argmin), + "where": self._where, # tensor manipulation "cat": self._cat, "chunk": self._chunk, "concat": self._cat, "contiguous": lambda node: self.env[node.args[0]], + "cumprod": self._cumprod, "cumsum": self._cumsum, "expand": self._expand, "expand_as.default": self._expand_as, From 4b50a7f4da4b4d592815be5e0c5a7f955df365e4 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 29 Mar 2025 17:25:37 +0800 Subject: [PATCH 2/4] Update base_fx_graph_translator.py --- .../frontend/torch/base_fx_graph_translator.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 71554a8a5bab..fe0ae412a228 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -949,6 +949,12 @@ def convert(node: fx.Node): return convert + def _where(self, node: fx.Node) -> relax.Var: + condition = self.env[node.args[0]] + x = self.env[node.args[1]] + y = self.env[node.args[2]] + return self.block_builder.emit(relax.op.where(condition, x, y)) + ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: @@ -967,6 +973,17 @@ def _chunk(self, node: fx.Node) -> relax.Var: relax.op.split(x=x, indices_or_sections=n_sections, axis=dim) ) + def _cumprod(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + if "dtype" in node.kwargs: + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + + return self.block_builder.emit(relax.op.cumprod(x, dim, dtype)) + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] From 7b9554abd0063f4460f046bf85c6f41024af8a9a Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 29 Mar 2025 17:27:23 +0800 Subject: [PATCH 3/4] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 61 +++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 726ff6f8e81d..ac58ba65a511 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2339,6 +2339,25 @@ def main( verify_model(LogSoftmax(), input_info, {}, expected_log_softmax) verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax) + # reciprocal + class Reciprocal(Module): + def forward(self, input): + return torch.reciprocal(input) + + @tvm.script.ir_module + class expected_reciprocal: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(R.const(1.0, "float32"), input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Reciprocal(), input_info, {}, expected_reciprocal) + # relu class ReLU0(Module): def __init__(self): @@ -4315,5 +4334,47 @@ def main( verify_model(Prod(), [([5, 3], "float32")], {}, Expected) +def test_cumprod(): + class Cumprod(Module): + def forward(self, x): + return torch.cumprod(x, 0) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected) + + +def test_where(): + class Where(Module): + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="bool"), + inp_1: R.Tensor((5, 3), dtype="float32"), + inp_2: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From aef4e15206c4e8720a61977e0f7779810694b268 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 29 Mar 2025 23:40:52 +0800 Subject: [PATCH 4/4] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index ac58ba65a511..b8d7f0b14e5b 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2351,7 +2351,9 @@ def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(R.const(1.0, "float32"), input_1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + R.const(1.0, "float32"), input_1 + ) gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv @@ -4373,7 +4375,9 @@ def main( R.output(gv) return gv - verify_model(Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected) + verify_model( + Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected + ) if __name__ == "__main__":