From 70bf442588b103cffb73bfcfa835fa4a3fabdfc6 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 14 Sep 2024 21:33:41 +0900 Subject: [PATCH 1/4] cleanup `_mean()` --- python/tvm/relax/frontend/torch/fx_translator.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 1c4796a533a4..e7f712368ad2 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -884,6 +884,15 @@ def _unbind(self, node: fx.Node) -> relax.Var: ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) + ########## Statistical ########## + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1031,13 +1040,6 @@ def _sum(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) return self.block_builder.emit(relax.op.sum(args[0], args[1])) - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) - ########## DataType ########## def _float(self, node: fx.Node) -> relax.Var: From eff6b68b208612f681f85693515d8ff3ff230273 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 14 Sep 2024 21:33:41 +0900 Subject: [PATCH 2/4] cleanup `_sum()` --- python/tvm/relax/frontend/torch/fx_translator.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e7f712368ad2..207b28193e5a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -893,6 +893,13 @@ def _mean(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1031,15 +1038,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - ########## Statistical ########## - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - ########## DataType ########## def _float(self, node: fx.Node) -> relax.Var: From ecc6faf22cc935ad54c9f18d1ee513e1b8ebf730 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 14 Sep 2024 21:33:41 +0900 Subject: [PATCH 3/4] cleanup `_argmax_argmin()` --- .../tvm/relax/frontend/torch/fx_translator.py | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 207b28193e5a..e250573d0b8e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -900,6 +900,19 @@ def _sum(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) return self.block_builder.emit(relax.op.sum(args[0], args[1])) + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(op(x, dim, keepdim)) + + return convert + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = None - keepdims = False - - if len(node.args) > 1: - dim = node.args[1] - if len(node.args) > 2: - keepdims = node.args[2] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - if "keepdim" in node.kwargs: - keepdims = node.kwargs["keepdim"] - if "keepdims" in node.kwargs: - keepdims = node.kwargs["keepdims"] - - return self.block_builder.emit(op(x, dim, keepdims)) - - return convert - ########## Neural Network ########## def _softmax(self, node: fx.Node) -> relax.Var: From cb7c63191be45423619767914841e5a97fe00d7e Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 14 Sep 2024 21:33:41 +0900 Subject: [PATCH 4/4] cleanup datatype ops --- .../tvm/relax/frontend/torch/fx_translator.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e250573d0b8e..4dc49d20ff36 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -913,6 +913,32 @@ def convert(node: fx.Node): return convert + ########## DataType ########## + + def _float(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _to(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + def _type(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1051,32 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - ########## DataType ########## - - def _float(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - - def _half(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - - def _type(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - - def _to(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if len(node.args) == 2: - if isinstance(node.args[1], torch.dtype): - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - return x - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: