diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..4bfdb8c1bc36 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -908,6 +908,20 @@ def _mean(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + def _prod(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.prod(x, dim, keepdims=keepdim)) + + def _std(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.std(x, dim, keepdims=keepdim)) + def _sum(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False @@ -915,6 +929,13 @@ def _sum(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) return self.block_builder.emit(relax.op.sum(args[0], args[1])) + def _var(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim)) + ########## Search ########## def _argmax_argmin(self, op: Callable) -> Callable: diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 98de2e114bcd..022a7bffea80 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -777,7 +777,10 @@ def create_convert_map( "lerp": self._lerp, # statistical "mean": self._mean, + "prod": self._prod, + "std": self._std, "sum": self._sum, + "var": self._var, # search "argmax": self._argmax_argmin(relax.op.argmax), "argmin": self._argmax_argmin(relax.op.argmin), diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 962a6accff91..726ff6f8e81d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4255,5 +4255,65 @@ def main( ) +def test_std(): + class Std(Module): + def forward(self, x): + return torch.std(x) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False) + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Std(), [([5, 3], "float32")], {}, Expected) + + +def test_var(): + class Var(Module): + def forward(self, x): + return torch.var(x) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False) + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Var(), [([5, 3], "float32")], {}, Expected) + + +def test_prod(): + class Prod(Module): + def forward(self, x): + return torch.prod(x) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False) + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Prod(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main()