diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 1c4796a533a4..4dc49d20ff36 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -884,6 +884,61 @@ def _unbind(self, node: fx.Node) -> relax.Var:
             ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
         return self.block_builder.emit(relax.Tuple(ret))
 
+    ########## Statistical ##########
+
+    def _mean(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+
+    def _sum(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
+        if len(args) == 1:
+            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
+        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+
+    ########## Search ##########
+
+    def _argmax_argmin(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node):
+            x = self.env[node.args[0]]
+            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+            return self.block_builder.emit(op(x, dim, keepdim))
+
+        return convert
+
+    ########## DataType ##########
+
+    def _float(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
+
+    def _half(self, node: fx.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
+
+    def _to(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
+    def _type(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+        return self.block_builder.emit(relax.op.astype(x, dtype))
+
     ########## Creation ##########
 
     def _arange(self, node: fx.Node) -> relax.Var:
@@ -1022,48 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var:
             )
         )
 
-    ########## Statistical ##########
-
-    def _sum(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
-
-    def _mean(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim))
-
-    ########## DataType ##########
-
-    def _float(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
-
-    def _half(self, node: fx.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
-
-    def _type(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
-        return self.block_builder.emit(relax.op.astype(x, dtype))
-
-    def _to(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        x = self.env[node.args[0]]
-        if len(node.args) == 2:
-            if isinstance(node.args[1], torch.dtype):
-                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
-                return self.block_builder.emit(relax.op.astype(x, dtype))
-        elif "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
-            return self.block_builder.emit(relax.op.astype(x, dtype))
-        return x
-
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
@@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = output
         return output
 
-    ########## Search ##########
-
-    def _argmax_argmin(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node):
-            x = self.env[node.args[0]]
-            dim = None
-            keepdims = False
-
-            if len(node.args) > 1:
-                dim = node.args[1]
-            if len(node.args) > 2:
-                keepdims = node.args[2]
-
-            if "dim" in node.kwargs:
-                dim = node.kwargs["dim"]
-            if "keepdim" in node.kwargs:
-                keepdims = node.kwargs["keepdim"]
-            if "keepdims" in node.kwargs:
-                keepdims = node.kwargs["keepdims"]
-
-            return self.block_builder.emit(op(x, dim, keepdims))
-
-        return convert
-
     ########## Neural Network ##########
 
     def _softmax(self, node: fx.Node) -> relax.Var:
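
Note: beyond reordering the section blocks, this patch changes behavior in two converters. The relocated _mean and _argmax_argmin now read dim and keepdim from positional args first and fall back to node.kwargs, whereas the old _mean ignored a kwargs-only dim (it only branched on len(args)) and the old _argmax_argmin needed four separate kwargs checks. A minimal standalone sketch of why the kwargs fallback matters, not part of the patch (the module name M is just for illustration):

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        # dim/keepdim passed as keywords, as user code commonly does
        return torch.mean(x, dim=-1, keepdim=True)

for node in fx.symbolic_trace(M()).graph.nodes:
    if node.op == "call_function":
        # args is (x,) and kwargs is {'dim': -1, 'keepdim': True}, so a converter
        # that only inspects node.args would silently reduce over all dims;
        # the node.kwargs.get(...) fallback in the new _mean picks them up.
        print(node.target, node.args, node.kwargs)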