diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index aed38d7c49ea..8d66343254c1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -35,7 +35,7 @@ def __init__(self) -> None: import torch # type: ignore from torch import fx - self.env: Dict[fx.node.Node, relax.Expr] = {} + self.env: Dict[fx.Node, relax.Expr] = {} self.params: Dict[torch.Tensor, relax.Expr] = {} self.named_modules: Dict[str, torch.Module] = None self.block_builder: relax.BlockBuilder = None @@ -108,7 +108,7 @@ def retrieve_args(self, node): def _retrieve_args(self, node): from torch import fx - if isinstance(node, fx.node.Node): + if isinstance(node, fx.Node): return self.env[node] elif isinstance(node, tuple): return tuple(self._retrieve_args(x) for x in node) @@ -136,33 +136,113 @@ def _call_binary_op(self, op, lhs, rhs): lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) return self.block_builder.emit(op(lhs, rhs)) - ########## Arithmetic ########## + ########## Unary Ops ########## - def _exp(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) + def _unary_op(self, op: Callable) -> Callable: + from torch import fx - def _sigmoid(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]])) + def convert(node: fx.Node) -> relax.Var: + return self.block_builder.emit(op(self.env[node.args[0]])) - def _sqrt(self, node: fx.node.Node) -> relax.Expr: - arg = self.env[node.args[0]] - if isinstance(arg, (int, float)): - arg = relax.const(arg, "float32") - return self.block_builder.emit(relax.op.sqrt(arg)) + return convert - def _rsqrt(self, node: fx.node.Node) -> relax.Expr: - arg = self.env[node.args[0]] - if isinstance(arg, (int, float)): - arg = relax.const(arg, "float32") - return self.block_builder.emit(relax.op.rsqrt(arg)) + def _clamp(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = args[1] if len(args) > 1 else node.kwargs["min"] + a_max = args[2] if len(args) > 2 else node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + def _gelu(self, node: fx.Node) -> relax.Expr: + approximate = node.kwargs.get("approximate", "none") + if approximate == "none": + return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) + elif approximate == "tanh": + return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) + else: + raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + + def _hardsigmoid(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + + def _hardswish(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + x2 = relax.op.divide(x1, relax.const(6, 
dtype)) + return self.block_builder.emit(relax.op.multiply(x, x2)) + + def _leakyrelu(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _leakyrelu_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + alpha = module.negative_slope + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _log_softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + + def _log_softmax_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + dim = module.dim + assert dim is not None + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _round(self, node: fx.node.Node) -> relax.Expr: - if "decimals" in node.kwargs and node.kwargs["decimals"] != 0: + def _round(self, node: fx.Node) -> relax.Expr: + if node.kwargs.get("decimals", 0) != 0: raise ValueError("specifying decimals for round is not supported yet") arg = self.env[node.args[0]] return self.block_builder.emit(relax.op.round(arg)) - def _add(self, node: fx.node.Node) -> relax.Expr: + def _softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _softmax_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + dim = module.dim + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + + ########## Arithmetic ########## + + def _add(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.add, lhs, rhs) @@ -176,103 +256,54 @@ def _add(self, node: fx.node.Node) -> relax.Expr: ) return lhs + rhs - def _max(self, node: fx.node.Node) -> relax.Expr: + def _max(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.maximum, lhs, rhs) - def _floordiv(self, node: fx.node.Node) -> relax.Expr: + def _floordiv(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.floor_divide, lhs, rhs) return lhs // rhs - def _mul(self, node: fx.node.Node) -> relax.Expr: + def _mul(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.multiply, lhs, rhs) return lhs * rhs - def _pow(self, node: fx.node.Node) -> relax.Expr: + def _pow(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return 
self._call_binary_op(relax.op.power, lhs, rhs) return lhs**rhs - def _neg(self, node: fx.node.Node) -> relax.Expr: - x = self.env[node.args[0]] - return self.block_builder.emit(relax.op.negative(x)) - - def _sub(self, node: fx.node.Node) -> relax.Expr: + def _sub(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.subtract, lhs, rhs) return lhs - rhs - def _truediv(self, node: fx.node.Node) -> relax.Expr: + def _truediv(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.divide, lhs, rhs) return lhs / rhs - def _clamp(self, node: fx.node.Node) -> relax.Expr: - args = self.retrieve_args(node) - a_min = node.kwargs["min"] - a_max = node.kwargs["max"] - if not isinstance(a_min, (int, float)): - raise ValueError( - f"TVM only supports constant min value for torch.clamp/clip, " - f"but got {a_min} with type {type(a_min)}" - ) - if not isinstance(a_max, (int, float)): - raise ValueError( - f"TVM only supports constant max value for torch.clamp/clip, " - f"but got {a_max} with type {type(a_max)}" - ) - return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) - - def _gelu(self, node: fx.node.Node) -> relax.Expr: - if "approximate" not in node.kwargs: - approximate = "none" - else: - approximate = node.kwargs["approximate"] - if approximate == "none": - return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) - elif approximate == "tanh": - return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) - else: - raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) - - def _hardsigmoid(self, node: fx.node.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) - - def _hardswish(self, node: fx.node.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - x2 = relax.op.divide(x1, relax.const(6, dtype)) - return self.block_builder.emit(relax.op.multiply(x, x2)) - ########## Compare ########## - def _lt(self, node: fx.node.Node) -> relax.Expr: + def _lt(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.less, lhs, rhs) - def _eq(self, node: fx.node.Node) -> relax.Expr: + def _eq(self, node: fx.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.equal, lhs, rhs) ########## Creation ########## - def _arange(self, node: fx.node.Node) -> relax.Var: + def _arange(self, node: fx.Node) -> relax.Var: import torch start_end_step = [None, None, None] @@ -311,15 +342,15 @@ def _arange(self, node: fx.node.Node) -> relax.Var: else: dtype = "int64" start_end_step = [ - self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in start_end_step + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step ] return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - def _empty(self, node: fx.node.Node) -> relax.Var: + def _empty(self, node: fx.Node) -> relax.Var: dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) return 
self.block_builder.emit(relax.op.zeros(node.args, dtype)) - def _inplace_fill(self, node: fx.node.Node) -> relax.Var: + def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] dtype = x.struct_info.dtype @@ -328,7 +359,7 @@ def _inplace_fill(self, node: fx.node.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _tensor(self, node: fx.node.Node) -> relax.Var: + def _tensor(self, node: fx.Node) -> relax.Var: dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None if isinstance(node.args[0], float): return relax.const(node.args[0], dtype if dtype is not None else "float32") @@ -336,21 +367,10 @@ def _tensor(self, node: fx.node.Node) -> relax.Var: return relax.const(node.args[0], dtype if dtype is not None else "int64") raise ValueError("torch.tensor with value not a float or int is not accepted") - def _tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else 0 - assert isinstance(k, int) - return self.block_builder.emit(op(x, k)) - - return convert - def _inplace_tril_triu(self, op: Callable) -> Callable: from torch import fx - def convert(node: fx.node.Node) -> relax.Var: + def convert(node: fx.Node) -> relax.Var: x = self.env[node.args[0]] k = node.args[1] if len(node.args) > 1 else 0 assert isinstance(k, int) @@ -361,7 +381,7 @@ def convert(node: fx.node.Node) -> relax.Var: return convert - def _new_ones(self, node: fx.node.Node) -> relax.Var: + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] size = args[1:] @@ -376,7 +396,7 @@ def _new_ones(self, node: fx.node.Node) -> relax.Var: ) ) - def _ones(self, node: fx.node.Node) -> relax.Var: + def _ones(self, node: fx.Node) -> relax.Var: import torch args = self.retrieve_args(node) @@ -397,7 +417,7 @@ def _ones(self, node: fx.node.Node) -> relax.Var: ) ) - def _full(self, node: fx.node.Node) -> relax.Var: + def _full(self, node: fx.Node) -> relax.Var: import torch args = self.retrieve_args(node) @@ -421,14 +441,14 @@ def _full(self, node: fx.node.Node) -> relax.Var: ########## Statistical ########## - def _sum(self, node: fx.node.Node) -> relax.Var: + def _sum(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False if len(args) == 1: return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) return self.block_builder.emit(relax.op.sum(args[0], args[1])) - def _mean(self, node: fx.node.Node) -> relax.Var: + def _mean(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False if len(args) == 1: @@ -437,18 +457,18 @@ def _mean(self, node: fx.node.Node) -> relax.Var: ########## DataType ########## - def _float(self, node: fx.node.Node) -> relax.Var: + def _float(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - def _half(self, node: fx.node.Node) -> relax.Var: + def _half(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - def _type(self, node: fx.node.Node) -> relax.Var: + def _type(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) return self.block_builder.emit(relax.op.astype(x, dtype)) - def _to(self, node: 
fx.node.Node) -> relax.Var: + def _to(self, node: fx.Node) -> relax.Var: import torch x = self.env[node.args[0]] @@ -466,7 +486,7 @@ def _to(self, node: fx.node.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _matmul(self, node: fx.node.Node) -> relax.Var: + def _matmul(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) res = self._matmul_impl( args[0], @@ -474,7 +494,7 @@ def _matmul(self, node: fx.node.Node) -> relax.Var: ) return res - def _addmm(self, node: fx.node.Node) -> relax.Var: + def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] z = self.env[node.args[2]] @@ -496,7 +516,7 @@ def _addmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res - def _baddbmm(self, node: fx.node.Node) -> relax.Var: + def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] a = self.env[node.args[1]] b = self.env[node.args[2]] @@ -518,7 +538,7 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res - def _einsum(self, node: fx.node.Node) -> relax.Var: + def _einsum(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -526,7 +546,7 @@ def _einsum(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) - def _unbind(self, node: fx.node.Node) -> relax.Var: + def _unbind(self, node: fx.Node) -> relax.Var: if len(node.args) == 2: assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" dim = node.args[1] @@ -544,12 +564,12 @@ def _unbind(self, node: fx.node.Node) -> relax.Var: ########## Manipulation ########## - def _cat(self, node: fx.node.Node) -> relax.Var: + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) - def _expand(self, node: fx.node.Node) -> relax.Var: + def _expand(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) broadcast_shape, in_shape = [], self.shape_of(args[0]) for idx, i in enumerate(args[1:]): @@ -559,7 +579,7 @@ def _expand(self, node: fx.node.Node) -> relax.Var: broadcast_shape.append(i) return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - def _flatten(self, node: fx.node.Node) -> relax.Var: + def _flatten(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -579,7 +599,7 @@ def _flatten(self, node: fx.node.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.reshape(x, new_shape)) - def _permute(self, node: fx.node.Node) -> relax.Var: + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -587,7 +607,7 @@ def _permute(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) - def _reshape(self, node: fx.node.Node) -> relax.Var: + def _reshape(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) 
@@ -595,7 +615,7 @@ def _reshape(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - def _split(self, node: fx.node.Node) -> relax.Var: + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] if "dim" in node.kwargs: @@ -611,7 +631,7 @@ def _split(self, node: fx.node.Node) -> relax.Var: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size return self.block_builder.emit(relax.op.split(x, n_section, dim)) - def _chunk(self, node: fx.node.Node) -> relax.Var: + def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] @@ -623,13 +643,13 @@ def _chunk(self, node: fx.node.Node) -> relax.Var: dim = 0 return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _transpose(self, node: fx.node.Node) -> relax.Var: + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) full_idx = list(range(len(self.shape_of(args[0])))) full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - def _squeeze(self, node: fx.node.Node) -> relax.Var: + def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if "dim" in node.kwargs: @@ -640,7 +660,7 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) - def _repeat(self, node: fx.node.Node) -> relax.Var: + def _repeat(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -648,7 +668,7 @@ def _repeat(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _tile(self, node: fx.node.Node) -> relax.Var: + def _tile(self, node: fx.Node) -> relax.Var: import torch # type: ignore args = self.retrieve_args(node) @@ -656,7 +676,7 @@ def _tile(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _cumsum(self, node: fx.node.Node) -> relax.Var: + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if "dim" in node.kwargs: @@ -674,13 +694,13 @@ def _cumsum(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - def _index_select(self, node: fx.node.Node) -> relax.Var: + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] index = self.env[node.args[2]] return self.block_builder.emit(relax.op.take(x, index, dim)) - def _masked_fill(self, node: fx.node.Node) -> relax.Var: + def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] value = node.args[2] @@ -688,7 +708,7 @@ def _masked_fill(self, node: fx.node.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: + def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] value = node.args[2] @@ -703,7 +723,7 @@ def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var: def _argmax_argmin(self, op: Callable) -> 
Callable: from torch import fx - def convert(node: fx.node.Node): + def convert(node: fx.Node): x = self.env[node.args[0]] dim = None keepdims = False @@ -726,14 +746,14 @@ def convert(node: fx.node.Node): ########## Neural Network ########## - def _linear(self, node: fx.node.Node) -> relax.Var: + def _linear(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = None if module.bias is None else self.params[module.bias] return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _linear_functional(self, node: fx.node.Node) -> relax.Var: + def _linear_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -770,7 +790,7 @@ def _conv1d_impl( bias = relax.op.reshape(bias, (1, -1, 1)) return self.block_builder.emit(relax.op.add(conv1d, bias)) - def _conv1d(self, node: fx.node.Node) -> relax.Var: + def _conv1d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -788,7 +808,7 @@ def _conv1d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv1d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -838,7 +858,7 @@ def _conv1d_transpose_impl( bias = relax.op.reshape(bias, (1, -1, 1)) return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv1d_transpose(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -856,7 +876,7 @@ def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -905,7 +925,7 @@ def _conv2d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d, bias)) - def _conv2d(self, node: fx.node.Node) -> relax.Var: + def _conv2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -923,7 +943,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv2d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -973,7 +993,7 @@ def _conv2d_transpose_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv2d_transpose(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -991,7 +1011,7 @@ def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -1040,7 +1060,7 @@ def _conv3d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) return 
self.block_builder.emit(relax.op.add(conv3d, bias)) - def _conv3d(self, node: fx.node.Node) -> relax.Var: + def _conv3d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1058,7 +1078,7 @@ def _conv3d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv3d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -1077,7 +1097,7 @@ def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + def _max_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1108,7 +1128,7 @@ def _max_pool2d(self, node: fx.node.Node) -> relax.Var: ) ) - def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: + def _avg_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1154,7 +1174,7 @@ def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: from torch import fx - def _impl(node: fx.node.Node) -> relax.Var: + def _impl(node: fx.Node) -> relax.Var: if is_module: module = self.named_modules[node.target] x = self.env[node.args[0]] @@ -1168,7 +1188,7 @@ def _impl(node: fx.node.Node) -> relax.Var: return _impl - def _softmax(self, node: fx.node.Node) -> relax.Var: + def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: module = self.named_modules[node.target] @@ -1179,29 +1199,7 @@ def _softmax(self, node: fx.node.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - def _log_softmax(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - dim = module.dim - else: - nargs = len(node.args) - dim = node.args[1] if nargs > 1 else node.kwargs["dim"] - assert dim is not None - return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - - def _leakyrelu(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - alpha = module.negative_slope - else: - nargs = len(node.args) - alpha = node.args[1] if nargs > 1 else node.kwargs["negative_slope"] - assert alpha is not None - return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - - def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + def _batch_norm_2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1224,7 +1222,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _layer_norm(self, node: fx.node.Node) -> relax.Var: + def _layer_norm(self, node: fx.Node) -> relax.Var: import torch # type: ignore from torch.fx.immutable_collections import immutable_list import numpy as np # type: ignore @@ -1291,7 +1289,7 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: ) ) - def _group_norm(self, node: fx.node.Node) -> relax.Var: + def _group_norm(self, node: fx.Node) -> relax.Var: import torch # type: ignore x = 
self.env[node.args[0]] @@ -1317,7 +1315,7 @@ def _group_norm(self, node: fx.node.Node) -> relax.Var: ) ) - def _embedding(self, node: fx.node.Node) -> relax.Var: + def _embedding(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1333,7 +1331,7 @@ def _embedding(self, node: fx.node.Node) -> relax.Var: embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _interpolate(self, node: fx.node.Node) -> relax.Var: + def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, # recompute_scale_factor=None, antialias=False) @@ -1407,7 +1405,7 @@ def _interpolate(self, node: fx.node.Node) -> relax.Var: ) ) - def _cross_entropy(self, node: fx.node.Node) -> relax.Expr: + def _cross_entropy(self, node: fx.Node) -> relax.Expr: preds = self.env[node.args[0]] targets = self.env[node.args[1]] @@ -1442,7 +1440,7 @@ def _cross_entropy(self, node: fx.node.Node) -> relax.Expr: ) ) - def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: assert ( len(node.args) <= 4 ), "Dropout is not supported, and is_causal should be called by kwargs." @@ -1464,13 +1462,13 @@ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: ########## Others ########## - def _sym_size_int(self, node: fx.node.Node) -> relax.Expr: + def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) idx = node.args[1] return self.block_builder.emit(relax.const(shape[idx].value, "int32")) - def _size(self, node: fx.node.Node) -> relax.Expr: + def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) if len(node.args) == 1: @@ -1480,7 +1478,7 @@ def _size(self, node: fx.node.Node) -> relax.Expr: idx = node.args[1] return self.shape_of(x)[idx].value - def _getattr(self, node: fx.node.Node) -> relax.Var: + def _getattr(self, node: fx.Node) -> relax.Var: if isinstance(self.env[node.args[0]], relax.Expr): if node.args[1] == "dtype": return self.env[node.args[0]].struct_info.dtype @@ -1488,7 +1486,7 @@ def _getattr(self, node: fx.node.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _getitem(self, node: fx.node.Node) -> relax.Var: + def _getitem(self, node: fx.Node) -> relax.Var: import torch x = self.env[node.args[0]] @@ -1510,7 +1508,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: shape = self.shape_of(x) non_ellipsis_cnt = 0 for index in node.args[1]: - if isinstance(index, (int, slice, torch.fx.node.Node)): + if isinstance(index, (int, slice, torch.fx.Node)): non_ellipsis_cnt += 1 for index in node.args[1]: if isinstance(index, int): @@ -1534,7 +1532,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: stride.append(1) stride_axes.append(i) i += 1 - elif isinstance(index, torch.fx.node.Node): + elif isinstance(index, torch.fx.Node): node_index = self.env[index] if not isinstance(node_index, relax.Expr): raise ValueError( @@ -1573,142 +1571,154 @@ def create_convert_map(self): from torch import nn from torch import fx - self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { - # call_module - nn.Linear: self._linear, + self.convert_map: 
Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = { + ## call_module + # unary + nn.Dropout: lambda node: self.env[node.args[0]], + nn.GELU: self._gelu, + nn.Hardsigmoid: self._hardsigmoid, + nn.Hardswish: self._hardswish, + nn.Identity: lambda node: self.env[node.args[0]], + nn.LeakyReLU: self._leakyrelu_module, + nn.LogSoftmax: self._log_softmax_module, + nn.ReLU: self._unary_op(relax.op.nn.relu), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.Sigmoid: self._unary_op(relax.op.sigmoid), + nn.SiLU: self._unary_op(relax.op.nn.silu), + nn.Softmax: self._softmax_module, + nn.Tanh: self._unary_op(relax.op.tanh), + # neural network + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.AvgPool2d: self._avg_pool2d, + nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, nn.Conv3d: self._conv3d, nn.ConvTranspose1d: self._conv1d_transpose, nn.ConvTranspose2d: self._conv2d_transpose, - nn.MaxPool2d: self._max_pool2d, - nn.AvgPool2d: self._avg_pool2d, - nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), - nn.Softmax: self._softmax, - nn.LogSoftmax: self._log_softmax, - nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), - nn.LeakyReLU: self._leakyrelu, - nn.ReLU6: lambda node: self.block_builder.emit( - relax.op.clip(self.env[node.args[0]], 0, 6) - ), - nn.GELU: self._gelu, - nn.Sigmoid: self._sigmoid, - nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), - nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), - nn.Hardsigmoid: self._hardsigmoid, - nn.Hardswish: self._hardswish, - nn.Flatten: self._flatten, - nn.BatchNorm2d: self._batch_norm_2d, - nn.LayerNorm: self._layer_norm, + nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm, - nn.Dropout: lambda node: self.env[node.args[0]], - nn.Identity: lambda node: self.env[node.args[0]], + nn.LayerNorm: self._layer_norm, + nn.Linear: self._linear, + nn.MaxPool2d: self._max_pool2d, nn.modules.sparse.Embedding: self._embedding, - nn.CrossEntropyLoss: self._cross_entropy, - # call_function and call_method - "sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])), - "cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])), - "tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])), - "asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])), - "acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])), - "atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])), - "sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])), - "cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])), - "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), - "asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])), - "acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])), - "atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])), - "exp": self._exp, - "iadd": self._add, + # tensor manipulation + nn.Flatten: self._flatten, + ## call_function and call_method + # unary + "acos": self._unary_op(relax.op.acos), + "acosh": self._unary_op(relax.op.acosh), + "asin": self._unary_op(relax.op.asin), + "asinh": 
self._unary_op(relax.op.asinh), + "atan": self._unary_op(relax.op.atan), + "atanh": self._unary_op(relax.op.atanh), + "clamp": self._clamp, + "cos": self._unary_op(relax.op.cos), + "cosh": self._unary_op(relax.op.cosh), + "dropout": lambda node: self.env[node.args[0]], + "exp": self._unary_op(relax.op.exp), + "gelu": self._gelu, + "hardsigmoid": self._hardsigmoid, + "hardswish": self._hardswish, + "leaky_relu": self._leakyrelu, + "log_softmax": self._log_softmax, + "neg": self._unary_op(relax.op.negative), + "relu": self._unary_op(relax.op.nn.relu), + "round": self._round, + "rsqrt": self._unary_op(relax.op.rsqrt), + "sigmoid": self._unary_op(relax.op.sigmoid), + "silu": self._unary_op(relax.op.nn.silu), + "sin": self._unary_op(relax.op.sin), + "sinh": self._unary_op(relax.op.sinh), + "softmax": self._softmax, + "sqrt": self._unary_op(relax.op.sqrt), + "tan": self._unary_op(relax.op.tan), + "tanh": self._unary_op(relax.op.tanh), + "tril_": self._inplace_tril_triu(relax.op.tril), + "tril": self._tril_triu(relax.op.tril), + "triu_": self._inplace_tril_triu(relax.op.triu), + "triu": self._tril_triu(relax.op.triu), + # binary "add": self._add, + "eq": self._eq, "floordiv": self._floordiv, + "iadd": self._add, + "lt": self._lt, + "matmul": self._matmul, + "max": self._max, "mul": self._mul, - "sub": self._sub, "pow": self._pow, - "sigmoid": self._sigmoid, - "sqrt": self._sqrt, - "round": self._round, - "lt": self._lt, - "eq": self._eq, + "sub": self._sub, "truediv": self._truediv, - "fill_": self._inplace_fill, - "new_ones": self._new_ones, - "arange": self._arange, - "empty": self._empty, - "tensor": self._tensor, - "tril": self._tril_triu(relax.op.tril), - "triu": self._tril_triu(relax.op.triu), - "tril_": self._inplace_tril_triu(relax.op.tril), - "triu_": self._inplace_tril_triu(relax.op.triu), - "sum": self._sum, - "float": self._float, - "half": self._half, - "type": self._type, - "astype": self._type, - "matmul": self._matmul, - "conv1d": self._conv1d_functional, + # neural network + "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "addmm": self._addmm, + "avg_pool2d": self._avg_pool2d, + "baddbmm": self._baddbmm, + "bmm": self._matmul, "conv_transpose1d": self._conv1d_transpose_functional, - "conv2d": self._conv2d_functional, "conv_transpose2d": self._conv2d_transpose_functional, + "conv1d": self._conv1d_functional, + "conv2d": self._conv2d_functional, "conv3d": self._conv3d_functional, + "cross_entropy": self._cross_entropy, + "einsum": self._einsum, + "interpolate": self._interpolate, + "layer_norm": self._layer_norm, "linear": self._linear_functional, - "addmm": self._addmm, - "baddbmm": self._baddbmm, - "bmm": self._matmul, + "max_pool2d": self._max_pool2d, + "scaled_dot_product_attention": self._scaled_dot_product_attention, + "stochastic_depth": lambda node: self.env[node.args[0]], + "unbind": self._unbind, + # statistical + "mean": self._mean, + "sum": self._sum, + # search + "argmax": self._argmax_argmin(relax.op.argmax), + "argmin": self._argmax_argmin(relax.op.argmin), + # tensor manipulation "cat": self._cat, "concat": self._cat, + "contiguous": lambda node: self.env[node.args[0]], + "cumsum": self._cumsum, "expand": self._expand, "flatten": self._flatten, "permute": self._permute, "repeat": self._repeat, "reshape": self._reshape, + "size": self._size, "split": self._split, + "squeeze": self._squeeze, "tile": self._tile, - "cumsum": self._cumsum, - "chunk": self._chunk, "transpose": self._transpose, - "squeeze": self._squeeze, "unsqueeze": lambda node: 
self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) ), "view": self._reshape, - "argmax": self._argmax_argmin(relax.op.argmax), - "argmin": self._argmax_argmin(relax.op.argmin), - "softmax": self._softmax, - "log_softmax": self._log_softmax, - "dropout": lambda node: self.env[node.args[0]], - "stochastic_depth": lambda node: self.env[node.args[0]], - "clamp": self._clamp, - "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), - "leaky_relu": self._leakyrelu, - "gelu": self._gelu, - "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), - "hardsigmoid": self._hardsigmoid, - "hardswish": self._hardswish, - "interpolate": self._interpolate, - "sym_size.int": self._sym_size_int, - "size": self._size, - "getattr": self._getattr, - "getitem": self._getitem, - "contiguous": lambda node: self.env[node.args[0]], - "to": self._to, - "max_pool2d": self._max_pool2d, - "avg_pool2d": self._avg_pool2d, - "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), - "layer_norm": self._layer_norm, + # tensor creation + "arange": self._arange, + "chunk": self._chunk, + "empty": self._empty, + "fill_": self._inplace_fill, + "full": self._full, "index_select": self._index_select, + "masked_fill_": self._inplace_masked_fill, "masked_fill": self._masked_fill, + "new_ones": self._new_ones, "ones": self._ones, - "full": self._full, - "masked_fill_": self._inplace_masked_fill, - "mean": self._mean, - "rsqrt": self._rsqrt, - "neg": self._neg, - "max": self._max, - "cross_entropy": self._cross_entropy, - "scaled_dot_product_attention": self._scaled_dot_product_attention, - "einsum": self._einsum, - "unbind": self._unbind, + "tensor": self._tensor, + "to": self._to, + # datatype + "astype": self._type, + "float": self._float, + "half": self._half, + "type": self._type, + # other + "getattr": self._getattr, + "getitem": self._getitem, + "sym_size.int": self._sym_size_int, } def update_convert_map(self, custom_convert_map: dict):
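The core of the unary-op hunks above is the new `_unary_op` factory: instead of one hand-written method per elementwise op (`_exp`, `_sigmoid`, `_rsqrt`, ...), the convert map now registers entries like `"exp": self._unary_op(relax.op.exp)`. A minimal standalone sketch of that higher-order-converter pattern follows; the names (`FakeNode`, `make_unary`, `env`) are illustrative only and not part of the TVM code.

    from typing import Callable, Dict
    import math

    env: Dict[str, float] = {"x": 2.0}   # plays the role of TorchFXImporter.env

    class FakeNode:                       # stands in for torch.fx.Node
        def __init__(self, args):
            self.args = args

    def make_unary(op: Callable[[float], float]) -> Callable[[FakeNode], float]:
        # Same shape as _unary_op: capture the op once, return a per-node
        # converter that looks up the single input and applies the op to it.
        def convert(node: FakeNode) -> float:
            return op(env[node.args[0]])
        return convert

    convert_map = {
        "exp": make_unary(math.exp),
        "neg": make_unary(lambda v: -v),
        "sqrt": make_unary(math.sqrt),
    }
    print(convert_map["exp"](FakeNode(("x",))))   # ~7.389

This keeps every unary entry one line long in `create_convert_map` while the emitted Relax call stays identical to the old per-op methods.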
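The patch also splits several converters into module and functional variants (for example `_leakyrelu_module` reads `module.negative_slope`, while `_leakyrelu` resolves the slope from positional args, then kwargs, then PyTorch's default of 0.01). A small sketch of that resolution order, again with illustrative stand-in classes rather than real `torch.fx` objects:

    class FakeModule:
        negative_slope = 0.2

    class FakeNode:
        def __init__(self, args, kwargs=None, target=None):
            self.args = args
            self.kwargs = kwargs or {}
            self.target = target

    named_modules = {"act": FakeModule()}

    def leaky_relu_functional(node):
        # mirrors _leakyrelu: positional arg wins, then kwarg, then the 0.01 default
        return node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)

    def leaky_relu_module(node):
        # mirrors _leakyrelu_module: the slope lives on the registered nn.Module
        return named_modules[node.target].negative_slope

    print(leaky_relu_functional(FakeNode(("x",))))        # 0.01
    print(leaky_relu_functional(FakeNode(("x", 0.3))))    # 0.3
    print(leaky_relu_module(FakeNode(("x",), target="act")))  # 0.2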
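One way to exercise the reorganized convert map end to end is through the importer's public entry point. The sketch below assumes `from_fx` is exported from `tvm.relax.frontend.torch` (as this file defines) and takes a traced `fx.GraphModule` plus a list of `(shape, dtype)` input specs; the model and shapes are made up for illustration.

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class SmallBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.act = nn.GELU(approximate="tanh")   # hits the gelu_tanh branch of _gelu
            self.soft = nn.Softmax(dim=-1)           # hits _softmax_module

        def forward(self, x):
            x = self.act(x)
            x = torch.clamp(x, 0.0, 6.0)             # positional min/max handled by _clamp
            return self.soft(x)

    traced: fx.GraphModule = fx.symbolic_trace(SmallBlock())
    mod = from_fx(traced, [((1, 16), "float32")])    # Relax IRModule
    mod.show()

Both the `call_module` path (GELU, Softmax) and the `call_function` path (`torch.clamp`) go through the same dictionary built in `create_convert_map`, so a small mixed model like this covers several of the converters touched by this patch.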