diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py index 55da5a456d6a..add488a11426 100644 --- a/python/tvm/relax/frontend/torch/__init__.py +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -18,4 +18,5 @@ PyTorch Frontends for constructing Relax programs, with the model importers """ from .fx_translator import from_fx +from .exported_program_translator import from_exported_program from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py new file mode 100644 index 000000000000..087d32989e59 --- /dev/null +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -0,0 +1,1742 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch FX frontend of Relax.""" +from typing import Callable, Dict, List, Optional, Tuple, Union +from functools import reduce + +import tvm +from tvm import relax +import torch + + +class TorchFXImporter: + """An importer from PyTorch FX to Relax.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.node.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Expr] = {} + self.named_modules: Dict[str, torch.Module] = None + self.block_builder: relax.BlockBuilder = None + self.create_convert_map() + + ########## Utilities ########## + def _fetch_attr(self, model, target: str): + import torch # type: ignore + + target_atoms = target.split(".") + attr_itr = model + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced non existing target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + # Its possible for the resulting tensor to be a parameter. + # If so, return the parameter instead. 
+ if attr_itr in self.params: + return self.params[attr_itr] + return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return attr_itr + + @staticmethod + def _convert_data_type(input_type, env: Optional[Dict] = None): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + if env is not None and input_type in env: + input_type = env[input_type] + + input_type = input_type.lower() if isinstance(input_type, str) else input_type + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + elif input_type in ["int32", "torch.int32", torch.int32]: + return "int32" + elif input_type in ["bool", "torch.bool", torch.bool]: + return "bool" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), dtype) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.node.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + @staticmethod + def _promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def _call_binary_op(self, op, lhs, rhs): + lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + ########## Arithmetic ########## + + def _exp(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) + + def _sigmoid(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]])) + + def _sqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.sqrt(arg)) + + def _rsqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.rsqrt(arg)) + + def _round(self, node: fx.node.Node) -> 
relax.Expr: + if "decimals" in node.kwargs and node.kwargs["decimals"] != 0: + raise ValueError("specifying decimals for round is not supported yet") + arg = self.env[node.args[0]] + return self.block_builder.emit(relax.op.round(arg)) + + def _add(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.add, lhs, rhs) + elif isinstance(lhs, relax.expr.Constant): + return self._call_binary_op( + relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) + ) + elif isinstance(rhs, relax.expr.Constant): + return self._call_binary_op( + relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs + ) + return lhs + rhs + + def _maximum(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.maximum, lhs, rhs) + + def _floordiv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.floor_divide, lhs, rhs) + return lhs // rhs + + def _mul(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.multiply, lhs, rhs) + return lhs * rhs + + def _pow(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.power, lhs, rhs) + return lhs**rhs + + def _neg(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.negative(x)) + + def _sub(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.subtract, lhs, rhs) + return lhs - rhs + + def _truediv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.divide, lhs, rhs) + return lhs / rhs + + def _clamp(self, node: fx.node.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = node.kwargs["min"] + a_max = node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + def _gelu(self, node: fx.node.Node) -> relax.Expr: + if "approximate" not in node.kwargs: + approximate = "none" + else: + approximate = node.kwargs["approximate"] + if approximate == "none": + return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) + elif approximate == "tanh": + return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) + else: + raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + + ########## Compare ########## + + def _lt(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less, lhs, rhs) + + def _eq(self, node: fx.node.Node) -> 
relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.equal, lhs, rhs) + + ########## Creation ########## + + def _arange(self, node: fx.node.Node) -> relax.Var: + import torch + + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] + + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] + + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 + + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) + else: + dtype = "int64" + start_end_step = [ + self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in start_end_step + ] + return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) + + def _empty(self, node: fx.node.Node) -> relax.Var: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args, dtype)) + + def _inplace_fill(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + + def _tensor(self, node: fx.node.Node) -> relax.Var: + dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None + if isinstance(node.args[0], float): + return relax.const(node.args[0], dtype if dtype is not None else "float32") + elif isinstance(node.args[0], int): + return relax.const(node.args[0], dtype if dtype is not None else "int64") + raise ValueError("torch.tensor with value not a float or int is not accepted") + + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + + def _inplace_tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + + mutated = self.block_builder.emit(op(x, k)) + self.env[node.args[0]] = mutated + return mutated + + return convert + + def _new_ones(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + 
def _ones(self, node: fx.node.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) + + def _full(self, node: fx.node.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + def _mean(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = False + if len(args) > 2: + keepdim = args[2] + if len(args) == 1: + return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) + + ########## DataType ########## + + def _float(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + + def _to(self, node: fx.node.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + ########## Linear Algebra ########## + + def _matmul_impl(self, a: relax.Expr, b: relax.Expr): + return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) + + def _matmul(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + res = self._matmul_impl( + args[0], + args[1], + ) + return res + + def _addmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 + beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + if alpha != 1: + dtype = 
res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) + return res + + def _baddbmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + a = self.env[node.args[1]] + b = self.env[node.args[2]] + alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 + beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(a, b)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + + ########## Manipulation ########## + + def _cat(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + + def _expand(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1])) + + def _flatten(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1])) + + def _reshape(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + + def _split(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _chunk(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 2: + dim = node.args[2] + else: + dim = 0 + return self.block_builder.emit(relax.op.split(x, chunks, dim)) + + def _transpose(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = 
list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + def _squeeze(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] + else: + dim = None + return self.block_builder.emit(relax.op.squeeze(x, dim)) + + def _cumsum(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] + else: + dim = None + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") + + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) + + def _index_select(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) + + def _masked_fill(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + + def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) + self.env[node.args[0]] = output + return output + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.node.Node): + x = self.env[node.args[0]] + dim = None + keepdims = False + + if len(node.args) > 1: + dim = node.args[1] + if len(node.args) > 2: + keepdims = node.args[2] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + if "keepdim" in node.kwargs: + keepdims = node.kwargs["keepdim"] + if "keepdims" in node.kwargs: + keepdims = node.kwargs["keepdims"] + + return self.block_builder.emit(op(x, dim, keepdims)) + + return convert + + ########## Neural Network ########## + + def _convolution(self, node: fx.node.Node) -> relax.Var: + inputs = self.retrieve_args(node) + # Use transpose or normal + use_transpose = True if inputs[6] == 1 else False + + data = inputs[0] + weight = inputs[1] + bias = inputs[2] + strides = tuple(inputs[3]) + padding = tuple(inputs[4]) + dilation = tuple(inputs[5]) + + if isinstance(weight, relax.Expr): + inferred_shape = self.infer_shape(weight) + weight_shape = [] + for infer in inferred_shape: + weight_shape.append(infer) + else: + msg = f"Data type {type(weight)} could not be parsed in conv op" + raise AssertionError(msg) + + groups = int(inputs[8]) + + if use_transpose: + channels = weight_shape[1] * groups + in_channels = weight_shape[0] + else: + channels = weight_shape[0] + in_channels = weight_shape[1] + + # Check if this is depth wise convolution + # We need to reshape weight so that Relay could recognize this is depth wise + # weight_shape[1] is always in_channels // groups + # For depthwise, in_channels == groups, so weight_shape[1] == 1 + # If groups > 1 but weight_shape[1] != 1, this is 
group convolution + if groups > 1 and in_channels == 1: + channel_multiplier = channels // groups + new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:]) + weight = relax.transform.reshape(weight, new_weight_shape) + + kernel_size = weight_shape[2:] + use_bias = isinstance(bias, relax.Expr) + + # We are trying to invoke various relay operations through a single conv_op variable. + # However the function signatures for some operations have additional attributes so we + # pass these in along with the standard ones. + additional_arguments = dict() + + if use_transpose: + if len(kernel_size) == 3: + conv_op = relax.op.nn.conv3d_transpose + elif len(kernel_size) == 2: + conv_op = relax.op.nn.conv2d_transpose + else: + conv_op = relax.op.nn.conv1d_transpose + output_padding = tuple(inputs[7]) + additional_arguments["output_padding"] = output_padding + + else: + if len(kernel_size) == 3: + conv_op = relax.op.nn.conv3d + elif len(kernel_size) == 2: + conv_op = relax.op.nn.conv2d + else: + conv_op = relax.op.nn.conv1d + + if len(kernel_size) == 3: + data_layout = "NCDHW" + kernel_layout = "OIDHW" + if use_transpose: + # Transposed convolutions have IODHW layout. + kernel_layout = "IODHW" + elif len(kernel_size) == 2: + data_layout = "NCHW" + kernel_layout = "OIHW" + if use_transpose: + # Transposed convolutions have IOHW layout. + kernel_layout = "IOHW" + else: + data_layout = "NCW" + kernel_layout = "OIW" + if use_transpose: + # Transposed convolutions have IOW layout. + kernel_layout = "IOW" + + # Conv1d does not currently support grouped convolution so we convert it to conv2d + is_grouped_conv1d = False + if groups > 1 and len(kernel_size) == 1 and not use_transpose: + is_grouped_conv1d = True + conv_op = relax.op.nn.conv2d + kernel_size = [1] + kernel_size + strides = (1,) + strides + padding = (0,) + padding + dilation = (1,) + dilation + data = relax.op.expand_dims(data, axis=2) + weight = relax.op.expand_dims(weight, axis=2) + data_layout = "NCHW" + kernel_layout = "OIHW" + + conv_out = conv_op( + data, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout=data_layout, + kernel_layout=kernel_layout, + out_dtype='float32' + ) + if use_bias: + assert len(self.shape_of(bias)) == 1 + if data_layout == "NCDHW": + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + elif data_layout == "NCHW": + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + else: + bias = relax.op.reshape(bias, (1, -1, 1)) + + res = relax.op.add(conv_out, bias) + else: + res = conv_out + if is_grouped_conv1d: + # Because we conducted grouped conv1d convolution through conv2d we must + # squeeze the output to get the correct result. 
+ res = relax.op.squeeze(res, axis=[2]) + return res + + def _linear(self, node: fx.node.Node) -> relax.Var: + inputs = self.retrieve_args(node) + bias = inputs[2] + a_shape = self.infer_shape(inputs[0]) + b_shape = self.infer_shape(inputs[1]) + if len(a_shape) == 2 and len(b_shape) == 2: + mm_out = relax.op.nn.dense(inputs[0], inputs[1]) + elif len(b_shape) == 1: + mm_out = self.matmul([inputs[0], inputs[1]]) + else: + mm_out = self.matmul( + [inputs[0], relax.op.transpose(inputs[1], axes=(1, 0))] + ) + if isinstance(bias, relax.Expr): + bias_ndims = len(self.infer_shape_with_prelude(bias)) + if bias_ndims == 1: + return relax.op.nn.bias_add(mm_out, bias, axis=-1) + mm_dtype = self.infer_type_with_prelude(mm_out).dtype + return self.block_builder.emit(relax.op.linear(x, weight, bias, mm_dtype)) + return mm_out + + def _linear_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _conv1d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv1d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv3d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv3d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv3d, bias)) + + def _conv2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv1d_transpose + + bias = self.params[module.bias] + assert 
len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + + def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d_transpose + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + stride = node.args[2] if nargs > 2 else node.kwargs["stride"] + padding = node.args[3] if nargs > 3 else node.kwargs["padding"] + dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] + ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + + stride = kernel if stride is None else stride + + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + if nargs > 2: + stride = node.args[2] + elif "stride" in node.kwargs.keys(): + stride = node.kwargs["stride"] + else: + stride = None + if nargs > 3: + padding = node.args[3] + elif "padding" in node.kwargs.keys(): + padding = node.kwargs["padding"] + else: + padding = 0 + if nargs > 4: + ceil_mode = node.args[4] + elif "ceil_mode" in node.kwargs.keys(): + ceil_mode = node.kwargs["ceil_mode"] + else: + ceil_mode = False + + stride = kernel if stride is None else stride + + return 
self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + from torch import fx + + def _impl(node: fx.node.Node) -> relax.Var: + if is_module: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + else: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + return _impl + + def _softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _log_softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + + def _leakyrelu(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + alpha = module.negative_slope + else: + nargs = len(node.args) + alpha = node.args[1] if nargs > 1 else node.kwargs["negative_slope"] + assert alpha is not None + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype)) + running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) + running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + + def _layer_norm(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + normalized_shape = args[1] + if len(args) > 2: + gamma = args[2] + else: + gamma = relax.const(torch.ones_like(normalized_shape), x.struct_info.dtype) + + if len(args) > 3: + beta = args[3] + else: + beta = relax.const(torch.zeros_like(normalized_shape), x.struct_info.dtype) + + if len(args) > 4: + eps = args[4] + else: + eps = 1e-05 + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + output = relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + mean = relax.op.mean(x, axes) + rstd = relax.op.rsqrt(relax.op.variance(x)+relax.const(eps)) + tuple_res = tvm.relax.Tuple((output, mean, rstd)) + return self.block_builder.emit(tuple_res) + + def _group_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), 
x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=module.num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=module.eps, + ) + ) + + def _embedding(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _interpolate(self, node: fx.node.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = ( + node.args[1] + if len(node.args) > 1 + else (node.kwargs["size"] if "size" in node.kwargs else None) + ) + scale_factor = ( + node.args[2] + if len(node.args) > 2 + else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None) + ) + method = ( + node.args[3] + if len(node.args) > 3 + else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest") + ) + align_corners = ( + node.args[4] + if len(node.args) > 4 + else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None) + ) + recompute_scale_factor = ( + node.args[5] + if len(node.args) > 5 + else ( + node.kwargs["recompute_scale_factor"] + if "recompute_scale_factor" in node.kwargs + else None + ) + ) + antialias = ( + node.args[6] + if len(node.args) > 6 + else (node.kwargs["antialias"] if "antialias" in node.kwargs else False) + ) + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + if isinstance(scale_factor, tuple): + assert len(scale_factor) == len(shape) - 2 + size = tuple( + int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + ) + else: + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + def _cross_entropy(self, node: fx.node.Node) -> relax.Expr: + preds = self.env[node.args[0]] + targets = self.env[node.args[1]] + + # functional.cross_entropy + if node.target not in self.named_modules: + weights = node.kwargs["weight"] + if weights is not None: + weights = self.env[weights] + reduction = node.kwargs["reduction"] + ignore_index = node.kwargs["ignore_index"] + + return self.block_builder.emit( + relax.op.nn.nll_loss( + 
relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index + ) + ) + + module = self.named_modules[node.target] + + weights = module.weight + if weights is not None: + if weights in self.params: + weights = self.params[weights] + else: + weights = relax.const(weights.numpy(), preds.struct_info.dtype) + reduction = module.reduction + ignore_index = module.ignore_index + + return self.block_builder.emit( + relax.op.nn.nll_loss( + relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index + ) + ) + + def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: + assert ( + len(node.args) <= 4 + ), "Dropout is not supported, and is_causal should be called by kwargs." + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + causal_mask = "TopLeft" if node.kwargs.get("is_causal", False) else None + + if len(node.args) == 4: + mask = self.env[node.args[3]] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in mask.struct_info.dtype, msg + attn = relax.op.nn.attention(query, key, value, bias=mask, causal_mask=causal_mask) + else: + attn = relax.op.nn.attention(query, key, value, causal_mask=causal_mask) + + return self.block_builder.emit(attn) + + ########## Others ########## + + def _size(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value + + def infer_shape(self, expr: relax.Expr): + return [int(val) for val in expr.struct_info.shape.values] + + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Expr): + if node.args[1] == "dtype": + return self.env[node.args[0]].struct_info.dtype + elif node.args[1] == "shape": + return self.shape_of(self.env[node.args[0]]) + return getattr(self.env[node.args[0]], node.args[1]) + + def _getitem(self, node: fx.node.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + take_indices = [] + take_axes = [] + stride_begin = [] + stride_end = [] + stride = [] + stride_axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + non_ellipsis_cnt = 0 + for index in node.args[1]: + if isinstance(index, (int, slice, torch.fx.node.Node)): + non_ellipsis_cnt += 1 + for index in node.args[1]: + if isinstance(index, int): + stride_begin.append(index) + stride_end.append(index + 1) + stride.append(1) + stride_axes.append(i) + i = i + 1 + elif isinstance(index, slice): + stride_begin.append(0 if index.start is None else index.start) + stride_end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + stride_axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(len(stride_axes) + len(expand_dim)) + elif index is Ellipsis: + for _ in range(len(shape) - non_ellipsis_cnt): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + elif 
isinstance(index, torch.fx.node.Node): + node_index = self.env[index] + if not isinstance(node_index, relax.Expr): + raise ValueError( + "Unsupported index type for relax.op.take: " + str(type(node_index)) + ) + take_indices.append(node_index) + take_axes.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + taken = x + if len(take_indices) > 1: + raise ValueError("Multiple tensors as index not yet supported") + for each_index, each_axis in zip(take_indices, take_axes): + taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) + sliced = self.block_builder.emit( + relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) + ) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + elif isinstance(x, relax.Constant): + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) + else: + assert False + + def create_convert_map(self): + from torch import nn + from torch import fx + + self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { + # core aten ops + "aten::sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])), + "aten::cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])), + "aten::tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])), + "aten::asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])), + "aten::acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])), + "aten::atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])), + "aten::sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])), + "aten::cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])), + "aten::tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), + "aten::asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])), + "aten::acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])), + "aten::atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])), + "aten::exp": self._exp, + "aten::add.Tensor": self._add, + "aten::mul.Scalar": self._mul, + "aten::mul.Tensor": self._mul, + "aten::sub.Tensor": self._sub, + "aten::sigmoid": self._sigmoid, + "aten::sqrt": self._sqrt, + "aten::round": self._round, + "aten::arange.start_step": self._arange, + "aten::sum.dim_IntList": self._sum, + "aten::convolution": self._convolution, + "aten::addmm": self._addmm, + "aten::mm": self._matmul, + "aten::bmm": self._matmul, + "aten::cat": self._cat, + "aten::expand": self._expand, + "aten::permute": self._permute, + "aten::split_with_sizes": self._split, + "aten::cumsum": self._cumsum, + "aten::squeeze.dims": self._squeeze, + "aten::unsqueeze": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), + "aten::view": self._reshape, + "aten::argmax": self._argmax_argmin(relax.op.argmax), + "aten::argmin": self._argmax_argmin(relax.op.argmin), + "aten::_softmax": self._softmax, + "aten::_log_softmax": self._log_softmax, + "aten::native_dropout": lambda node: self.env[node.args[0]], + 
"aten::clamp": self._clamp, + "aten::relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + "aten::gelu": self._gelu, + "aten::avg_pool2d": self._avg_pool2d, + "aten::_adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "aten::native_layer_norm": self._layer_norm, + "aten::full": self._full, + "aten::mean.dim": self._mean, + "aten::rsqrt": self._rsqrt, + "aten::neg": self._neg, + "aten::maximum": self._maximum, + "aten::_to_copy": lambda x: x, + "aten::hardtanh": lambda x: x, + "aten::max_pool2d_with_indices": lambda x: x, + "aten::embedding": lambda x: x, + "aten::native_group_norm": lambda x: x, + "aten::select.int": lambda x: x, + "aten::slice.Tensor": lambda x: x, + "aten::le.Scalar": lambda x: x, + "aten::ge.Scalar": lambda x: x, + "aten::scalar_tensor": lambda x: x, + "aten::where.self": lambda x: x, + } + + ### [QIT ExportedProgram] + def get_parameter(self, + node: torch.fx.Node, + exported_program: torch.export.ExportedProgram) -> torch.Tensor: + param = None + if node.name in exported_program.graph_signature.inputs_to_parameters: + param = exported_program.state_dict[ + exported_program.graph_signature.inputs_to_parameters[node.name] + ].data + if node.name in exported_program.graph_signature.inputs_to_buffers: + param = exported_program.state_dict[ + exported_program.graph_signature.inputs_to_buffers[node.name] + ] + if param is not None: + # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32) + assert isinstance(param, torch.Tensor), "Expect parameter to be tensor" + param = param.type(node.meta["val"].dtype) + return param + ### [QIT ExportedProgram End] + + def from_exported_program( + self, + model, + input_info: List[Tuple[Tuple[int], str]], + keep_params_as_input: bool, + unwrap_unit_return_tuple: bool, + no_bind_return_tuple: bool, + ) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program.""" + from torch import fx + + model = model.run_decompositions() + graph: fx.Graph = model.graph + + # Create input variables. + inputs = [] + input_params = [] + for node in graph.nodes: + if node.op == "placeholder": + shape = list(node.meta['tensor_meta'].shape) + dtype = str(node.meta['tensor_meta'].dtype).split('.')[-1] + self.env[node] = relax.Var(node.name, + relax.TensorStructInfo(shape, dtype)) + parameter = self.get_parameter(node, model) + if parameter is not None: + input_params.append(self.env[node]) + else: + inputs.append(self.env[node]) + + # Initialize the block builder with a function and a dataflow block. + func_name = "main" + self.block_builder = relax.BlockBuilder() + params = [] + + with self.block_builder.function(name=func_name, params=inputs.copy()): + output = None + with self.block_builder.dataflow(): + # Translate model parameters. + for node in graph.nodes: + param = self.get_parameter(node, model) + if param is not None: + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + # if dtype in ("float32", "float16"): + if not keep_params_as_input: + self.params[param] = relax.const(param.data.cpu().numpy(), dtype) + self.env[node] = self.params[param] + # else: + # raise ValueError("Unsupported data type for model parameters: %s" % dtype) + # Translate the model. 
+                for node in graph.nodes:
+                    if node.op != "output" and len(node.users) == 0:
+                        continue
+                    if node.op == "placeholder":
+                        continue
+                    elif node.op == "output":
+                        args = self.retrieve_args(node)
+                        assert len(args) == 1
+
+                        # return tuple
+                        if isinstance(args[0], (tuple, list, relax.Tuple)):
+                            # unit tuple
+                            if unwrap_unit_return_tuple and len(args[0]) == 1:
+                                output = self.block_builder.emit_output(args[0][0])
+                            elif no_bind_return_tuple:
+                                output = []
+                                for ret in args[0]:
+                                    output.append(self.block_builder.emit_output(ret))
+
+                        if output is None:
+                            output = self.block_builder.emit_output(args[0])
+                        break
+                    elif node.op == "get_attr":
+                        self.env[node] = self._fetch_attr(model, node.target)
+                    elif node.name == "getitem":
+                        self.env[node] = self.block_builder.emit(
+                            relax.TupleGetItem(self.env[node.args[0]], node.args[1])
+                        )
+                    elif node.op == "call_function":
+                        func_name = node.target.name()
+                        assert (
+                            func_name in self.convert_map
+                        ), f"Unsupported function type {func_name}"
+                        self.env[node] = self.convert_map[func_name](node)
+                    else:
+                        raise ValueError(f"Unsupported op {node.op}")
+                assert output is not None
+
+            self.block_builder.emit_func_output(output)
+
+        mod = self.block_builder.get()
+        if keep_params_as_input:
+            mod["main"] = mod["main"].with_attr("params", params)
+        return mod
+
+
+def from_exported_program(
+    model,
+    input_info: List[Tuple[Tuple[int], str]],
+    *,
+    keep_params_as_input: bool = False,
+    unwrap_unit_return_tuple: bool = True,
+    no_bind_return_tuple: bool = False,
+) -> tvm.IRModule:
+    """Convert a PyTorch ExportedProgram to a Relax program.
+
+    Parameters
+    ----------
+    model : ExportedProgram
+        The PyTorch ExportedProgram to convert.
+
+    input_info : List[Tuple[Tuple[int], str]]
+        A list of shapes and data types of input tensors.
+
+    keep_params_as_input : bool
+        Whether to keep model parameters as input variables.
+
+    unwrap_unit_return_tuple : bool
+        A boolean flag indicating whether to unwrap the return value when it is
+        a unit tuple. When the return value is not a unit tuple, no unwrap will
+        take place.
+
+    no_bind_return_tuple : bool
+        A boolean flag indicating whether to bind the return tuple as a relax var.
+        If the flag is true and the return value is a tuple, it will not bind it to a var.
+
+    Returns
+    -------
+    output : tvm.IRModule
+        The import result IRModule, with the function "main" containing the
+        translated logic.
+        If `keep_params_as_input` is true, the "main" function has an attribute
+        "params" that contains the weights of the input model. The weights
+        can be detached by `relax.frontend.detach_params`.
+
+    Examples
+    --------
+    Users can use torch.export.export() to export an ExportedProgram from a
+    PyTorch model. The following code shows how to convert a PyTorch model to
+    a Relax program.
+
+    .. code-block:: python
+
+        # Import the importer.
+        import numpy as np
+        import torch
+        import tvm
+        from tvm.relax.frontend.torch import from_exported_program
+
+        # Define the module
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)
+
+            def forward(self, input):
+                return self.linear(input)
+
+        # Instantiate the model and create the input info.
+        torch_model = MyModule()
+        input_info = [((128, 10), "float32")]
+        input_tensors = [
+            torch.from_numpy(np.random.randn(*shape).astype(dtype))
+            for shape, dtype in input_info
+        ]
+
+        # Use torch.export.export() to export the PyTorch model to an ExportedProgram.
+        try:
+            exported_program = torch.export.export(torch_model, tuple(input_tensors))
+        except Exception as err:
+            raise RuntimeError("Failed to export the PyTorch model to ExportedProgram.") from err
+
+        # Use the importer to import the PyTorch model to Relax.
+        mod: tvm.IRModule = from_exported_program(exported_program, input_info)
+
+        # Print out the imported model.
+        print(mod.script())
+
+    Notes
+    -----
+    For a given PyTorch model and example inputs, to look up the names of the
+    model inputs in the exported graph, one can use
+
+    .. code-block:: python
+
+        torch.export.export(model, example_args).graph.print_tabular()
+
+    to print out the tabular representation of the exported PyTorch module, and
+    then check the placeholder rows at the beginning of the table.
+    """
+    return TorchFXImporter().from_exported_program(
+        model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
+    )
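The new importer leans on `ExportedProgram.graph_signature` to tell real graph inputs apart from parameter and buffer placeholders (see `get_parameter` above). The following standalone sketch of that mapping, independent of TVM, may help when reviewing this path; the module and the printed placeholder names are illustrative and depend on the installed torch version.

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


exported = torch.export.export(TinyModel(), (torch.randn(1, 4),))

# Placeholder names in the exported graph -> fully qualified state_dict keys.
# get_parameter() consults exactly these tables to resolve weight tensors.
print(exported.graph_signature.inputs_to_parameters)
# e.g. {'p_linear_weight': 'linear.weight', 'p_linear_bias': 'linear.bias'}
print(list(exported.state_dict.keys()))
# ['linear.weight', 'linear.bias']

Because the exported program carries its own state dict, the translator can bind weights as relax constants directly (the default) or leave them as extra function inputs when `keep_params_as_input=True`.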
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
new file mode 100644
index 000000000000..1ec6fe65a835
--- /dev/null
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -0,0 +1,1374 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest +import torch +import torch.nn.functional as F +from torch import fx +from torch.nn import Module + +import tvm +from tvm import relax +import tvm.testing +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T +from tvm.relax.frontend import detach_params +from tvm.relax.frontend.torch import from_exported_program + + +def verify_model(torch_model, input_info, binding, expected): + dummy_inputs = [] + for i in range(len(input_info)): + input_shape, dtype = input_info[i] + if dtype=='int32': + dtype = torch.int32 + else: + dtype = torch.float32 + dummy_input = torch.zeros(*tuple(input_shape), dtype=dtype) + dummy_inputs.append(dummy_input) + exported_program = torch.export.export(torch_model, tuple(dummy_inputs)) + with torch.no_grad(): + mod = from_exported_program(exported_program, input_info) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + + +def test_conv1d(): + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + gv: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + R.output(gv) + return gv + + input_info = [([1, 3, 10], "float32")] + + model = Conv1D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, input_info, binding, expected2) + + +def test_conv1d_transpose(): + class ConvTranspose1d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + 
kernel_layout="IOW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + gv: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) + R.output(gv) + return gv + + class ConvTranspose1d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + ) -> R.Tensor((1, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="IOW", + out_layout="NCW", + out_dtype="float32", + ) + R.output(gv) + return gv + + input_info = [([1, 6, 4], "float32")] + + model = ConvTranspose1d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = ConvTranspose1d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, input_info, binding, expected2) + + +def test_conv2d(): + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv2D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Conv2D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv2D1Func() + binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, 
input_info, binding, expected2) + + +def test_conv2d_transpose(): + class ConvTranspose2d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 16, 16), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="IOHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + gv: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) + R.output(gv) + return gv + + class ConvTranspose2d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 3, 16, 16), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="IOHW", + out_layout="NCHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = ConvTranspose2d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = ConvTranspose2d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, input_info, binding, expected2) + + +def test_conv3d(): + class Conv3D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) + R.output(gv) + return gv + + class Conv3D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + 
out_layout="NCDHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10, 10], "float32")] + + model = Conv3D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv3D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, input_info, binding, expected2) + + +def test_linear(): + # nn.Linear + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + class Dense1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[7, 10]) + self.bias = torch.randn(size=[7]) + + def forward(self, input): + return torch.nn.functional.linear(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((7,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10])) + lv1: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0]) + lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32") + lv3: R.Tensor((30, 7), dtype="float32") = R.add(w2, lv2) + lv4: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(lv3, R.shape([1, 3, 10, 7])) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv4 + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0]) + lv1: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10])) + lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32") + lv3: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(lv2, R.shape([1, 3, 10, 7])) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv3 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Dense1() + binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense1Func() + binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.detach().numpy()} + verify_model(model, input_info, binding, expected2) + + # matmul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + 
MatMul1(), + [([10, 10], "float32"), ([10, 10], "float32")], + {}, + expected3, + ) + + +def test_bmm(): + class BMM(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((4, 128, 256), dtype="float32"), + input_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + BMM(), + [((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected, + ) + + +def test_baddbmm(): + class BAddBMM1(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + class BAddBMM2(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2, out_dtype="float32") + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv1 + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tensor((4, 128, 512), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2, out_dtype="float32") + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + gv: R.Tensor((4, 128, 512), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model( + BAddBMM1(), + [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected1, + ) + + verify_model( + BAddBMM2(), + [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")], + {}, + Expected2, + ) + + +def test_relu(): + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU0(), input_info, {}, expected) + verify_model(ReLU1(), input_info, {}, expected) + + +def test_layernorm(): + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), 
dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3), dtype="float32") = R.mean(input_1, axis=[-2, -1], keepdims=False) + lv2: R.Tensor((), dtype="float32") = R.variance(input_1, axis=None, keepdims=False) + lv3: R.Tensor((), dtype="float32") = R.add(lv2, R.const(1e-5, "float32")) + lv4: R.Tensor((), dtype="float32") = R.rsqrt(lv3) + lv5: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3), dtype="float32"), R.Tensor((), dtype="float32")) = lv, lv1, lv4 + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = lv5[0] + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv6 + R.output(gv) + return gv + + model = LayerNorm() + binding = { + "w1": model.ln.weight.detach().numpy(), + "w2": model.ln.bias.detach().numpy(), + } + verify_model(LayerNorm(), input_info, binding, expected1) + + +def test_unary(): + input_info = [([1, 3, 10, 10], "float32")] + + # sin + class Sin(Module): + def forward(self, input): + return torch.sin(input) + + @tvm.script.ir_module + class expected_sin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sin(), input_info, {}, expected_sin) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected_cos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cos(), input_info, {}, expected_cos) + + # tan + class Tan(Module): + def forward(self, input): + return torch.tan(input) + + @tvm.script.ir_module + class expected_tan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tan(), input_info, {}, expected_tan) + + # asin + class Asin(Module): + def forward(self, input): + return torch.asin(input) + + @tvm.script.ir_module + class expected_asin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Asin(), input_info, {}, expected_asin) + + # acos + class Acos(Module): + def forward(self, input): + return torch.acos(input) + + @tvm.script.ir_module + class expected_acos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1) + gv: R.Tensor((1, 3, 10, 
10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Acos(), input_info, {}, expected_acos) + + # atan + class Atan(Module): + def forward(self, input): + return torch.atan(input) + + @tvm.script.ir_module + class expected_atan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Atan(), input_info, {}, expected_atan) + + # sinh + class Sinh(Module): + def forward(self, input): + return torch.sinh(input) + + @tvm.script.ir_module + class expected_sinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sinh(), input_info, {}, expected_sinh) + + # cosh + class Cosh(Module): + def forward(self, input): + return torch.cosh(input) + + @tvm.script.ir_module + class expected_cosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cosh(), input_info, {}, expected_cosh) + + # tanh + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected_tanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tanh(), input_info, {}, expected_tanh) + + # asinh + class Asinh(Module): + def forward(self, input): + return torch.asinh(input) + + @tvm.script.ir_module + class expected_asinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Asinh(), input_info, {}, expected_asinh) + + # acosh + class Acosh(Module): + def forward(self, input): + return torch.acosh(input) + + @tvm.script.ir_module + class expected_acosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Acosh(), input_info, {}, expected_acosh) + + # atanh + class Atanh(Module): + def forward(self, input): + return torch.atanh(input) + + @tvm.script.ir_module + class expected_atanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + 
R.output(gv) + return gv + + verify_model(Atanh(), input_info, {}, expected_atanh) + + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Exp(), input_info, {}, expected_exp) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sqrt(), input_info, {}, expected3) + + # sigmoid + class Sigmoid(Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, input): + return self.sigmoid(input) + + class Sigmoid2(Module): + def forward(self, input): + return torch.sigmoid(input) + + @tvm.script.ir_module + class expected4: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sigmoid(), input_info, {}, expected4) + verify_model(Sigmoid2(), input_info, {}, expected4) + + # round + class Round(Module): + def forward(self, input): + return torch.round(input) + + @tvm.script.ir_module + class expected5: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Round(), input_info, {}, expected5) + + +def test_expand(): + input_info = [([1, 2, 3, 4], "float32")] + + class Expand(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((4, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Expand(), input_info, {}, expected1) + + +def test_reduce(): + input_info = [([1, 2, 3, 4], "float32")] + + # sum + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tensor((1, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sum(), input_info, {}, expected1) + + +def test_permute(): + input_info = [([1, 2, 3, 4], "float32")] + + class Permute(Module): + def forward(self, x): + return x.permute(0, 3, 2, 
1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Permute(), input_info, {}, expected1) + + +def test_reshape(): + input_info = [([1, 2, 3, 4], "float32")] + + class Reshape(Module): + def forward(self, x): + return x.reshape(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Reshape(), input_info, {}, expected1) + + +def test_transpose(): + input_info = [([1, 2, 3, 4], "float32")] + + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Transpose(), input_info, {}, expected1) + + +def test_view(): + input_info = [([1, 2, 3, 4], "float32")] + + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(View(), input_info, {}, expected1) + + +def test_argmax(): + class Argmax1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1) + + class Argmax2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1, keepdim=True) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256,), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) + gv: R.Tensor((256,), dtype="int64") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256, 1), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) + gv: R.Tensor((256, 1), dtype="int64") = lv + R.output(gv) + return gv + + verify_model(Argmax1(), [([256, 256], "float32")], {}, Expected1) + verify_model(Argmax2(), [([256, 256], "float32")], {}, Expected2) + + +def test_argmin(): + class Argmin1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input) + + class Argmin2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input, keepdim=True) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), 
dtype="float32")) -> R.Tensor((), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) + gv: R.Tensor((), dtype="int64") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) + gv: R.Tensor((1, 1), dtype="int64") = lv + R.output(gv) + return gv + + verify_model(Argmin1(), [([256, 256], "float32")], {}, Expected1) + verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2) + + +def test_mean(): + class Mean(Module): + def forward(self, input): + return input.mean(-1) + + class MeanKeepDim(Module): + def forward(self, input): + return input.mean(-1, keepdim=True) + + @I.ir_module + class Expected1: + @R.function + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) + gv: R.Tensor((256,), dtype="float32") = lv + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 1), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) + gv: R.Tensor((256, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mean(), [([256, 256], "float32")], {}, Expected1) + verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2) + + +def test_rsqrt(): + class Rsqrt(Module): + def forward(self, input): + return torch.rsqrt(input) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.rsqrt(inp_0) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1) + + +def test_neg(): + class Neg(Module): + def forward(self, input): + return -input + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.negative(inp_0) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Neg(), [([256, 256], "float32")], {}, Expected1) + + +def test_max(): + class Max(Module): + def forward(self, x, y): + return torch.max(x, y) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.maximum(inp_0, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) + +if __name__ == "__main__": + tvm.testing.main()