From 5a590a3f1cc5d12980a382c359bc2b114ae1cef7 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 3 Dec 2020 14:16:17 +0100 Subject: [PATCH] Save PyTorch frontend state in object While the functional approach is pretty neat, we ended up having global state (default frontend, dtype) and it'll be more soon (caching of inferred types, see #6900). To not have to pass around the state, this moves the op conversion into a class with instances having the state. --- python/tvm/relay/frontend/pytorch.py | 2013 ++++++++++---------------- 1 file changed, 774 insertions(+), 1239 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 38478e27ff92..4f75cf380cc6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -17,6 +17,7 @@ # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda +# pylint: disable=missing-function-docstring """PT: PyTorch frontend.""" import itertools import logging @@ -133,16 +134,24 @@ def _is_quantized_tensor(data, prelude): # operator implementation -def _elemwise(name): - def _impl(inputs, input_types): - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) - return get_relay_op(name)(data0, data1) - return _impl +class PyTorchOpConverter: + """A helper class for holding PyTorch op converters.""" + + def __init__(self, prelude, default_dtype): + self.prelude = prelude + self.default_dtype = default_dtype + self.create_convert_map() + + def make_elemwise(self, name): + def elemwise(inputs, input_types): + data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + return get_relay_op(name)(data0, data1) + + return elemwise -def _min_max_common(name_elemwise, name_reduce): - def _impl(inputs, input_types): + def min_max_common(self, name_elemwise, name_reduce, inputs, input_types): if len(inputs) == 1: data = _pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name_reduce)(data[0]) @@ -156,38 +165,27 @@ def _impl(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name_elemwise)(data0, data1) - return _impl - + def max(self, inputs, input_types): + return self.min_max_common("maximum", "max", inputs, input_types) -def _max(): - return _min_max_common("maximum", "max") + def min(self, inputs, input_types): + return self.min_max_common("minimum", "min", inputs, input_types) + def make_unary(self, name): + def unary(inputs, input_types): + # this is just to ensure tensor input + (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) + return get_relay_op(name)(data) -def _min(): - return _min_max_common("minimum", "min") + return unary - -def _unary(name): - def _impl(inputs, input_types): - # this is just to ensure tensor input - (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) - return get_relay_op(name)(data) - - return _impl - - -def _log1p(): - def _impl(inputs, input_types): + def log1p(self, inputs, input_types): # 1_plus_log x = log(x + 1) (dtype,) = input_types one = _expr.const(1, dtype=dtype) return _op.log(inputs[0] + one) - return _impl - - -def _arange(): - def _impl(inputs, input_types): + def arange(self, inputs, input_types): def _get_value(val, dtype): # dtype is a 
tvm dtype if isinstance(val, _expr.Expr): @@ -235,11 +233,7 @@ def _get_type(val, inp_type): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _squeeze(): - def _impl(inputs, input_types): + def squeeze(self, inputs, input_types): data = inputs[0] if len(inputs) == 1: axis = None @@ -249,33 +243,27 @@ def _impl(inputs, input_types): return _op.transform.squeeze(data, axis) - return _impl - - -def _unsqueeze(): - def _impl(inputs, input_types): + def unsqueeze(self, inputs, input_types): data = inputs[0] axis = inputs[1] return _op.transform.expand_dims(data, int(axis), 1) - return _impl - - -def _concatenate(prelude): - def tensor_array_concat(lst, axis): - assert axis == 0, "Tensor array concat supported only for axis 0" - tensor_array, shape = _convert_to_tensor_array(lst, prelude) - concat_shape = (Any(),) + shape[1:] - concat = prelude.get_global_var_static("tensor_array_concat", "float32", shape) - concatenated = concat(tensor_array) - - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) - static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", concat_shape) - return get_tensor(concatenated) + def concatenate(self, inputs, input_types): + def tensor_array_concat(lst, axis): + assert axis == 0, "Tensor array concat supported only for axis 0" + tensor_array, shape = _convert_to_tensor_array(lst, self.prelude) + concat_shape = (Any(),) + shape[1:] + concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape) + concatenated = concat(tensor_array) + + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = self.prelude.get_global_var_static( + "tensor_get_data", "float32", concat_shape + ) + return get_tensor(concatenated) - def _impl(inputs, input_types): data = inputs[0] axis = inputs[1] @@ -287,11 +275,7 @@ def _impl(inputs, input_types): return _op.tensor.concatenate(data, int(axis)) - return _impl - - -def _slice(): - def _impl(inputs, input_types): + def slice(self, inputs, input_types): axis_dtype = "int64" index_size_limit = 2 ** 63 - 1 data = inputs[0] @@ -391,11 +375,7 @@ def _impl(inputs, input_types): data, begin=begin, end=end, strides=strides, slice_mode="end" ) - return _impl - - -def _split(): - def _impl(inputs, input_types): + def split(self, inputs, input_types): data = inputs[0] split_size = int(inputs[1]) dim = int(inputs[2]) @@ -408,11 +388,7 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _split_with_sizes(): - def _impl(inputs, input_types): + def split_with_sizes(self, inputs, input_types): data = inputs[0] sections = inputs[1] dim = int(inputs[2]) @@ -430,31 +406,19 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _select(): - def _impl(inputs, input_types): + def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) index = _wrap_const(inputs[2]) return _op.transform.take(data, index, axis=dim) - return _impl - - -def _take(): - def _impl(inputs, input_types): + def take(self, inputs, input_types): data = inputs[0] indices = _op.cast(inputs[1], "int32") return _op.transform.take(data, indices=indices) - return _impl - - -def _topk(): - def _impl(inputs, input_types): + def topk(self, inputs, input_types): data = inputs[0] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) @@ -473,28 +437,16 @@ def _impl(inputs, 
input_types): return outs[0], outs[1] - return _impl - - -def _reciprocal(): - def _impl(inputs, input_types): + def reciprocal(self, inputs, input_types): data = inputs[0] return _expr.const(1.0, dtype=input_types[0]) / data - return _impl - - -def _repeat(): - def _impl(inputs, input_types): + def repeat(self, inputs, input_types): data = inputs[0] reps = inputs[1] return _op.transform.tile(data, reps=reps) - return _impl - - -def _repeat_interleave(): - def _impl(inputs, input_types): + def repeat_interleave(self, inputs, input_types): data = inputs[0] if isinstance(inputs[1], int): repeats = inputs[1] @@ -507,77 +459,60 @@ def _impl(inputs, input_types): axis = 0 return _op.transform.repeat(data, repeats=repeats, axis=axis) - return _impl - - -def _addcdiv(): - def _impl(inputs, input_types): + def addcdiv(self, inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 / t2)) - return _impl - - -def _addcmul(): - def _impl(inputs, input_types): + def addcmul(self, inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 * t2)) - return _impl - - -def _where(): - def _impl(inputs, input_types): + def where(self, inputs, input_types): if len(inputs) == 1: - return _nonzero(False)([inputs[0], True], input_types) + return self.nonzero([inputs[0], True], input_types) cond = inputs[0] x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3]) return _op.where(cond, x, y) - return _impl - - -def _full_impl(data, fill_value, dtype): - size = [] - need_reshape = False - new_shape = [] - for dim in data: - if isinstance(dim, _expr.Expr): - if isinstance(dim, _expr.Constant): - dim = int(dim.data.asnumpy()) - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) - else: - dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) - new_shape.append(dim) - - if success: + def full_impl(self, data, fill_value, dtype): + size = [] + need_reshape = False + new_shape = [] + for dim in data: + if isinstance(dim, _expr.Expr): + if isinstance(dim, _expr.Constant): + dim = int(dim.data.asnumpy()) if isinstance(size, list): size.append(dim) + new_shape.append(dim) else: - size = None - need_reshape = True - else: - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) + dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) + new_shape.append(dim) - if size is None: - tmp = [] - for dim in data: - tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) - size = _op.concatenate(tmp, axis=0) + if success: + if isinstance(size, list): + size.append(dim) + else: + size = None + need_reshape = True + else: + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) - out = _op.full(_expr.const(fill_value), size, dtype=dtype) - if need_reshape: - out = _op.reshape(out, new_shape) - return out + if size is None: + tmp = [] + for dim in data: + tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) + size = _op.concatenate(tmp, axis=0) + out = _op.full(_expr.const(fill_value), size, dtype=dtype) + if need_reshape: + out = _op.reshape(out, new_shape) + return out -def _ones(default_dtype): - def _impl(inputs, input_types): + def ones(self, inputs, input_types): data = inputs[0] import torch @@ -589,14 +524,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 1, dtype) - - return _impl + dtype = self.default_dtype + 
return self.full_impl(data, 1, dtype) - -def _ones_like(default_dtype): - def _impl(inputs, input_types): + def ones_like(self, inputs, input_types): data = inputs[0] out = _op.ones_like(data) @@ -604,17 +535,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] != dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _zeros(default_dtype): - def _impl(inputs, input_types): + def zeros(self, inputs, input_types): data = inputs[0] import torch @@ -626,14 +553,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 0, dtype) - - return _impl - + dtype = self.default_dtype + return self.full_impl(data, 0, dtype) -def _zeros_like(default_dtype): - def _impl(inputs, input_types): + def zeros_like(self, inputs, input_types): data = inputs[0] out = _op.zeros_like(data) @@ -641,17 +564,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _full(default_dtype): - def _impl(inputs, input_types): + def full(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -665,15 +584,11 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype - - return _full_impl(data, fill_value, dtype) + dtype = self.default_dtype - return _impl + return self.full_impl(data, fill_value, dtype) - -def _full_like(default_dtype): - def _impl(inputs, input_types): + def full_like(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -684,17 +599,13 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _linspace(): - def _impl(inputs, input_types): + def linspace(self, inputs, input_types): start = inputs[0] stop = inputs[1] step = inputs[2] @@ -713,51 +624,31 @@ def _impl(inputs, input_types): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _relu(prelude): - def _impl(inputs, input_types): + def relu(self, inputs, input_types): data = inputs[0] - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 3, "Input quant param not found in op inputs" input_zero_point = _expr.const(inputs[2], dtype="int32") return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) - return _impl - - -def _prelu(): - def _impl(inputs, input_types): + def prelu(self, inputs, input_types): data = inputs[0] alpha = inputs[1] return _op.nn.prelu(data, alpha) - return _impl - - -def _leaky_relu(): - def _impl(inputs, input_types): + def leaky_relu(self, inputs, input_types): data = inputs[0] alpha = float(inputs[1]) return _op.nn.leaky_relu(data, alpha) - return _impl - - -def _elu(): - def _impl(inputs, input_types): + def elu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) return alpha * 
_op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) - return _impl - - -def _celu(): - def _impl(inputs, input_types): + def celu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) @@ -765,11 +656,7 @@ def _impl(inputs, input_types): _expr.const(1, dtype=dtype) - _op.exp(data / alpha) ) + _op.nn.relu(data) - return _impl - - -def _gelu(): - def _impl(inputs, input_types): + def gelu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] # gelu is data * normcdf(data) @@ -781,11 +668,7 @@ def _impl(inputs, input_types): + _op.erf(data * _expr.const(0.5 ** 0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype) ) - return _impl - - -def _selu(): - def _impl(inputs, input_types): + def selu(self, inputs, input_types): data = inputs[0] # https://pytorch.org/docs/stable/nn.html#selu dtype = input_types[0] @@ -795,65 +678,41 @@ def _impl(inputs, input_types): alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) ) - return _impl - - -def _log_sigmoid(): - def _impl(inputs, input_types): + def log_sigmoid(self, inputs, input_types): data = inputs[0] return _op.log(_op.tensor.sigmoid(data)) - return _impl - - -def _adaptive_avg_pool_2d(prelude): - def _impl(inputs, input_types): + def adaptive_avg_pool_2d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] def func(x): return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _adaptive_max_pool_2d(): - def _impl(inputs, input_types): + def adaptive_max_pool_2d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None - return _impl - - -def _adaptive_max_pool_3d(): - def _impl(inputs, input_types): + def adaptive_max_pool_3d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None - return _impl - - -def _adaptive_avg_pool_3d(): - def _impl(inputs, input_types): + def adaptive_avg_pool_3d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) - return _impl - - -def _maxpool_2d(): - def _impl(inputs, input_types): + def maxpool_2d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -868,19 +727,11 @@ def _impl(inputs, input_types): return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) - return _impl - - -def _maxpool_2d_with_indices(): - def _impl(inputs, input_types): + def maxpool_2d_with_indices(self, inputs, input_types): # returns dummy indices too - return _maxpool_2d()(inputs, input_types), None - - return _impl + return self.maxpool_2d(inputs, input_types), None - -def _maxpool_1d(): - def _impl(inputs, input_types): + def maxpool_1d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -895,11 +746,7 @@ def _impl(inputs, input_types): return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) - return _impl - - -def _maxpool_3d(): - def _impl(inputs, input_types): + def maxpool_3d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -915,21 +762,13 @@ def _impl(inputs, input_types): data, pool_size=pool_size, 
strides=strides, padding=padding, ceil_mode=ceil_mode ) - return _impl - - -def _hardtanh(): - def _impl(inputs, input_types): + def hardtanh(self, inputs, input_types): a = inputs[0] tanh_min = float(inputs[1]) tanh_max = float(inputs[2]) return _op.tensor.clip(a, tanh_min, tanh_max) - return _impl - - -def _convolution(): - def _impl(inputs, input_types): + def convolution(self, inputs, input_types): # Use transpose or normal use_transpose = True if inputs[6] == 1 else False @@ -1018,11 +857,7 @@ def _impl(inputs, input_types): res = _op.squeeze(res, axis=[2]) return res - return _impl - - -def _softmax(): - def _impl(inputs, input_types): + def softmax(self, inputs, input_types): data = inputs[0] axis = inputs[1] if isinstance(axis, str): @@ -1030,27 +865,15 @@ def _impl(inputs, input_types): return _op.nn.softmax(data, axis=axis) - return _impl - - -def _threshold(): - def _impl(inputs, input_types): + def threshold(self, inputs, input_types): data = inputs[0] return _op.nn.relu(data) - return _impl - - -def _contiguous(): - def _impl(inputs, input_types): + def contiguous(self, inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _batch_norm(): - def _impl(inputs, input_types): + def batch_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] @@ -1086,11 +909,7 @@ def _impl(inputs, input_types): scale=scale, )[0] - return _impl - - -def _instance_norm(): - def _impl(inputs, input_types): + def instance_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] channels = _infer_shape(data) @@ -1114,28 +933,24 @@ def _impl(inputs, input_types): data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale ) - return _impl - - -def _get_dims(data): - import torch - - if isinstance(data, _expr.Expr): - dims = _infer_shape(data) - elif isinstance(data, list): - dims = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - dims = data.shape - else: - msg = "Data type %s could not be parsed" % type(data) - raise AssertionError(msg) - return dims + @staticmethod + def get_dims(data): + import torch + if isinstance(data, _expr.Expr): + dims = _infer_shape(data) + elif isinstance(data, list): + dims = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + dims = data.shape + else: + msg = "Data type %s could not be parsed" % type(data) + raise AssertionError(msg) + return dims -def _layer_norm(): - def _impl(inputs, input_types): + def layer_norm(self, inputs, input_types): data = inputs[0] - ndims = len(_get_dims(inputs[1])) + ndims = len(self.get_dims(inputs[1])) assert ndims == 1, "Support only normalization over last one dimension." 
return _op.nn.layer_norm( @@ -1148,11 +963,7 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _group_norm(): - def _impl(inputs, input_types): + def group_norm(self, inputs, input_types): data = inputs[0] gamma = inputs[2] beta = inputs[3] @@ -1170,17 +981,13 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _transpose(prelude): - def _impl(inputs, input_types): + def transpose(self, inputs, input_types): data = inputs[0] import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data, prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -1211,11 +1018,7 @@ def _impl(inputs, input_types): axes = inputs[1] return _op.transform.transpose(data, axes) - return _impl - - -def _flatten(): - def _impl(inputs, input_types): + def flatten(self, inputs, input_types): data = inputs[0] start = int(inputs[1]) end = int(inputs[2]) @@ -1237,11 +1040,7 @@ def _impl(inputs, input_types): out = _op.squeeze(out, axis=squeeze_axes) return out - return _impl - - -def _addmm(): - def _impl(inputs, input_types): + def addmm(self, inputs, input_types): input_mat = inputs[0] mat1 = inputs[1] data_type = input_types[1] @@ -1265,35 +1064,24 @@ def _impl(inputs, input_types): return dense_out + input_mat - return _impl - - -def _size(prelude): - def _impl_dynamic(inp, axis): - shape_dynamic = _op.shape_of(inp, dtype="int32") - if axis is not None: - return _op.take(shape_dynamic, _expr.const(axis), 0) - return shape_dynamic - - def _impl(inputs, input_types): - shape = _infer_shape(inputs[0], prelude.mod) + def size(self, inputs, input_types): + shape = _infer_shape(inputs[0], self.prelude.mod) axis = None if len(inputs) > 1: axis = int(inputs[1]) if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): if axis is None or isinstance(shape[axis], tvm.tir.expr.Any): - return _impl_dynamic(inputs[0], axis) + shape_dynamic = _op.shape_of(inputs[0], dtype="int32") + if axis is not None: + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic if axis is not None: return _expr.const(shape[axis]) return _expr.const(shape) - return _impl - - -def _numtotensor(): - def _impl(inputs, input_types): + def numtotensor(self, inputs, input_types): val = inputs[0] dtype = input_types[0] @@ -1307,18 +1095,10 @@ def _impl(inputs, input_types): arr = val * np.ones([]).astype(dtype) return arr - return _impl - - -def _tensortonum(): - def _impl(inputs, input_types): + def tensortonum(self, inputs, input_types): return inputs[0] - return _impl - - -def _view(): - def _impl(inputs, input_types): + def view(self, inputs, input_types): data = inputs[0] if len(inputs) == 3: @@ -1336,11 +1116,7 @@ def _impl(inputs, input_types): return _op.transform.reshape(data, new_shape) - return _impl - - -def _reshape(): - def _impl(inputs, input_types): + def reshape(self, inputs, input_types): data = inputs[0] new_shape = inputs[1] @@ -1371,11 +1147,7 @@ def _impl(inputs, input_types): new_shape = tmp_shape return _op.transform.reshape(data, new_shape) - return _impl - - -def _pixel_shuffle(prelude): - def _impl(inputs, input_types): + def pixel_shuffle(self, inputs, input_types): data = inputs[0] upscale_factor = inputs[1] upscale_squared = upscale_factor * upscale_factor @@ -1384,7 +1156,7 @@ def _impl(inputs, input_types): c % upscale_squared == 0 ), "input channel should be divisible by square of upscale_factor" - ndims = len(_infer_shape(data, 
prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) axes = list(range(ndims)) num_inputs = len(inputs) oc = c // upscale_squared @@ -1402,46 +1174,26 @@ def _impl(inputs, input_types): data = _op.transform.transpose(data, axes) return _op.transform.reshape(data, out_shape) - return _impl - - -def _clone(): - def _impl(inputs, input_types): + def clone(self, inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _log_softmax(): - def _impl(inputs, input_types): + def log_softmax(self, inputs, input_types): data = inputs[0] axis = int(inputs[1]) return _op.nn.log_softmax(data, axis) - return _impl - - -def _sigmoid(): - def _impl(inputs, input_types): + def sigmoid(self, inputs, input_types): data = inputs[0] return _op.tensor.sigmoid(data) - return _impl - - -def _softplus(): - def _impl(inputs, input_types): + def softplus(self, inputs, input_types): data = inputs[0] dtype = input_types[0] beta = _expr.const(float(inputs[1]), dtype=dtype) return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta - return _impl - - -def _avg_pool2d(prelude): - def _impl(inputs, input_types): + def avg_pool2d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1460,16 +1212,12 @@ def func(x): count_include_pad=count_include_pad, ) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _avg_pool3d(): - def _impl(inputs, input_types): + def avg_pool3d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1487,41 +1235,32 @@ def _impl(inputs, input_types): count_include_pad=count_include_pad, ) - return _impl - - -def _dropout(): - def _impl(inputs, input_types): + def dropout(self, inputs, input_types): data = inputs[0] rate = float(inputs[1]) return _op.nn.dropout(data, rate) - return _impl - - -def _reduce(name): - def _impl(inputs, input_types): - data = inputs[0] - axis = None - keepdims = False - - if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False - if isinstance(inputs[1], int): - axis = int(inputs[1]) - elif _is_int_seq(inputs[1]): - axis = inputs[1] - else: - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[2]) + def make_reduce(self, name): + def reduce(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False - return get_relay_op(name)(data, axis=axis, keepdims=keepdims) + if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False + if isinstance(inputs[1], int): + axis = int(inputs[1]) + elif _is_int_seq(inputs[1]): + axis = inputs[1] + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[2]) - return _impl + return get_relay_op(name)(data, axis=axis, keepdims=keepdims) + return reduce -def _norm(): - def _impl(inputs, input_types): + def norm(self, inputs, input_types): data = inputs[0] dtype = input_types[0] axis = None @@ -1543,11 +1282,7 @@ def _impl(inputs, input_types): reci_order, ) - return _impl - - -def _frobenius_norm(): - def _impl(inputs, input_types): + def frobenius_norm(self, inputs, input_types): data = inputs[0] axis = None keepdims = False @@ -1557,11 +1292,7 @@ def _impl(inputs, input_types): return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims)) - return _impl - - -def _std(): - def _impl(inputs, input_types): + def std(self, inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1574,11 +1305,7 @@ def 
_impl(inputs, input_types): return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _variance(): - def _impl(inputs, input_types): + def variance(self, inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1591,11 +1318,7 @@ def _impl(inputs, input_types): return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _mean(prelude): - def _impl(inputs, input_types): + def mean(self, inputs, input_types): data = inputs[0] if inputs[1]: @@ -1615,7 +1338,7 @@ def _impl(inputs, input_types): def func(x): return _op.mean(x, axis, keepdims, exclude) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 6, "Input quant param not found in op inputs" input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) @@ -1623,18 +1346,14 @@ def func(x): return func(data) - return _impl - - -def _chunk(prelude): - def _impl(inputs, input_types): + def chunk(self, inputs, input_types): data = inputs[0] num_chunks = int(inputs[1]) axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data, prelude.mod) + inferred_shape = _infer_shape(data, self.prelude.mod) shape = [] for infer in inferred_shape: @@ -1670,18 +1389,14 @@ def _impl(inputs, input_types): return chunks - return _impl - - -def _matmul(prelude): - def _impl(inputs, input_types): + def matmul(self, inputs, input_types): inputs_0 = inputs[0] inputs_1 = inputs[1] # Need to check input shape as batch matmul must be supported. - a_shape = _infer_shape(inputs_0, prelude.mod) - b_shape = _infer_shape(inputs_1, prelude.mod) + a_shape = _infer_shape(inputs_0, self.prelude.mod) + b_shape = _infer_shape(inputs_1, self.prelude.mod) # When performing a batch matmul, we need to properly handle N-dim shapes. 
if len(a_shape) > 2 or len(b_shape) > 2: @@ -1689,8 +1404,8 @@ def _impl(inputs, input_types): a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) # Broadcast b to match batch size of a - new_b_shape = list(_infer_shape(b, prelude.mod)) - new_a_shape = _infer_shape(a, prelude.mod) + new_b_shape = list(_infer_shape(b, self.prelude.mod)) + new_a_shape = _infer_shape(a, self.prelude.mod) if new_a_shape[0] > new_b_shape[0]: new_b_shape[0] = new_a_shape[0] b = _op.broadcast_to(b, new_b_shape) @@ -1714,11 +1429,7 @@ def _impl(inputs, input_types): return out - return _impl - - -def _expand(): - def _impl(inputs, input_types): + def expand(self, inputs, input_types): data_in = inputs[0] shape = list(_infer_shape(data_in)) @@ -1740,85 +1451,64 @@ def _impl(inputs, input_types): return out - return _impl - - -def _int(): - def _impl(inputs, input_types): + def int(self, inputs, input_types): if isinstance(inputs[0], _expr.Expr): return inputs[0] return int(inputs[0]) - return _impl - - -def _identity(): - def _impl(inputs, input_types): + def identity(self, inputs, input_types): return inputs[0] - return _impl - - -def _none(): - def _impl(inputs, input_types): + def none(self, inputs, input_types): return None - return _impl - - -def _pad(mode): - def _impl(inputs, input_types): - data = inputs[0] - if isinstance(inputs[1], list): - pad_list = inputs[1] - else: - pad_list = list(_infer_shape(inputs[1])) - - # initialize paddings based on input len - pad_len = len(_infer_shape(data)) * 2 - paddings = [0] * pad_len - - if len(pad_list) >= 2: - paddings[-1] = pad_list[1] - paddings[-2] = pad_list[0] - if len(pad_list) >= 4: - paddings[-3] = pad_list[3] - paddings[-4] = pad_list[2] - if len(pad_list) >= 6: - paddings[-5] = pad_list[5] - paddings[-6] = pad_list[4] - - # group into tuple of 2 ints - paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] - - const_paddings = [] - for pad in paddings: - const_paddings.append([]) - for p in pad: - if not isinstance(p, int): - p = int(_infer_value(p, {}).asnumpy()) - const_paddings[-1].append(p) - - if mode == "constant": - return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) - else: - return _op.nn.pad(data, const_paddings, pad_mode=mode) - - return _impl + def make_pad(self, mode): + def pad(inputs, input_types): + data = inputs[0] + if isinstance(inputs[1], list): + pad_list = inputs[1] + else: + pad_list = list(_infer_shape(inputs[1])) + + # initialize paddings based on input len + pad_len = len(_infer_shape(data)) * 2 + paddings = [0] * pad_len + + if len(pad_list) >= 2: + paddings[-1] = pad_list[1] + paddings[-2] = pad_list[0] + if len(pad_list) >= 4: + paddings[-3] = pad_list[3] + paddings[-4] = pad_list[2] + if len(pad_list) >= 6: + paddings[-5] = pad_list[5] + paddings[-6] = pad_list[4] + + # group into tuple of 2 ints + paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] + + const_paddings = [] + for pad in paddings: + const_paddings.append([]) + for p in pad: + if not isinstance(p, int): + p = int(_infer_value(p, {}).asnumpy()) + const_paddings[-1].append(p) + + if mode == "constant": + return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) + else: + return _op.nn.pad(data, const_paddings, pad_mode=mode) + return pad -def _clamp(): - def _impl(inputs, input_types): + def clamp(self, inputs, input_types): data = inputs[0] amin = inputs[1] if inputs[1] else np.finfo(np.float32).min amax = inputs[2] if inputs[2] 
else np.finfo(np.float32).max return _op.clip(data, amin, amax) - return _impl - - -def _to(): - def _impl(inputs, input_types): + def to(self, inputs, input_types): data = inputs[0] dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2] # special handling for aten::to(data, 6, _, _, _) case @@ -1844,87 +1534,81 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _get_upsample_out_size(inputs, method): - # This assumes a static shape - out_size = [] - if inputs[1] is not None: - for size in inputs[1]: - if not isinstance(size, int): - out_size.append(int(_infer_value(size, {}).asnumpy())) - else: - out_size.append(size) - else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 - scales = inputs[scale_index] - assert scales is not None, "neither out size nor scale provided" - assert isinstance(scales, list) - ishape = _infer_shape(inputs[0]) - for i, scale in enumerate(scales): - out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) - - return out_size - - -def _upsample(method, prelude): - def _impl(inputs, input_types): - data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) - - if len(inputs) > 2 and method == "bilinear": - align_corners = inputs[2] - else: - align_corners = False - - if method == "nearest_neighbor": - coord_trans = "asymmetric" - elif align_corners: - coord_trans = "align_corners" + @staticmethod + def get_upsample_out_size(inputs, method): + # This assumes a static shape + out_size = [] + if inputs[1] is not None: + for size in inputs[1]: + if not isinstance(size, int): + out_size.append(int(_infer_value(size, {}).asnumpy())) + else: + out_size.append(size) else: - coord_trans = "half_pixel" - - def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scales = inputs[scale_index] + assert scales is not None, "neither out size nor scale provided" + assert isinstance(scales, list) + ishape = _infer_shape(inputs[0]) + for i, scale in enumerate(scales): + out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) + + return out_size + + def make_upsample(self, method): + def upsample(inputs, input_types): + data = inputs[0] + out_size = self.get_upsample_out_size(inputs, method) + + if len(inputs) > 2 and method == "bilinear": + align_corners = inputs[2] + else: + align_corners = False - if _is_quantized_tensor(data, prelude): - # input qparams are manually appended by us - assert isinstance(inputs[-2], float) - assert isinstance(inputs[-1], int) - input_scale = _expr.const(inputs[-2]) - input_zero_point = _expr.const(inputs[-1]) - return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func) + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" - return func(data) + def func(x): + return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - return _impl + if _is_quantized_tensor(data, self.prelude): + # input qparams are manually appended by us + assert isinstance(inputs[-2], float) + assert isinstance(inputs[-1], int) + input_scale = _expr.const(inputs[-2]) + input_zero_point = _expr.const(inputs[-1]) + return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func) + return func(data) -def _upsample3d(method): - def _impl(inputs, input_types): - data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) + return upsample - if len(inputs) 
> 2 and method == "trilinear": - align_corners = inputs[2] - else: - align_corners = False + def make_upsample3d(self, method): + def upsample3d(inputs, input_types): + data = inputs[0] + out_size = self.get_upsample_out_size(inputs, method) - if method == "nearest_neighbor": - coord_trans = "asymmetric" - elif align_corners: - coord_trans = "align_corners" - else: - coord_trans = "half_pixel" + if len(inputs) > 2 and method == "trilinear": + align_corners = inputs[2] + else: + align_corners = False - return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" - return _impl + return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) + return upsample3d -def _expand_as(): - def _impl(inputs, input_types): + def expand_as(self, inputs, input_types): target = inputs[1] t0 = _infer_type(inputs[0]).checked_type.dtype t1 = _infer_type(inputs[1]).checked_type.dtype @@ -1932,34 +1616,18 @@ def _impl(inputs, input_types): target = _op.cast(target, t0) return _op.broadcast_to_like(inputs[0], target) - return _impl - - -def _Bool(): - def _impl(inputs, input_types): + def Bool(self, inputs, input_types): assert len(inputs) == 1 return inputs[0] - return _impl - - -def _Float(): - def _impl(inputs, input_types): + def Float(self, inputs, input_types): assert len(inputs) == 1 return _op.cast(inputs[0], "float32") - return _impl - - -def _mm(): - def _impl(inputs, input_types): + def mm(self, inputs, input_types): return _op.nn.dense(inputs[0], inputs[1]) - return _impl - - -def _bitwise_not(): - def _impl(inputs, input_types): + def bitwise_not(self, inputs, input_types): data = inputs[0] # The input tensor must be of integral or Boolean types. 
# For bool tensors, it computes the logical NOT @@ -1970,11 +1638,7 @@ def _impl(inputs, input_types): return out - return _impl - - -def _bitwise_xor(): - def _impl(inputs, input_types): + def bitwise_xor(self, inputs, input_types): lhs = inputs[0] rhs = inputs[1] lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int") @@ -1982,91 +1646,55 @@ def _impl(inputs, input_types): return _op.bitwise_xor(lhs, rhs) - return _impl - - -def _logical_not(): - def _impl(inputs, input_types): + def logical_not(self, inputs, input_types): data = _wrap_const(inputs[0]) return _op.logical_not(_op.cast(data, "bool")) - return _impl - - -def _logical_xor(): - def _impl(inputs, input_types): + def logical_xor(self, inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_xor(lhs, rhs) - return _impl - - -def _list_getitem(prelude): - def _impl(inputs, input_types): - return prelude.nth(inputs[0], _wrap_const(inputs[1])) - - return _impl - - -def _list_len(prelude): - def _impl(inputs, input_types): - return prelude.length(inputs[0]) + def list_getitem(self, inputs, input_types): + return self.prelude.nth(inputs[0], _wrap_const(inputs[1])) - return _impl + def list_len(self, inputs, input_types): + return self.prelude.length(inputs[0]) - -def _type_as(): - def _impl(inputs, input_types): + def type_as(self, inputs, input_types): assert len(inputs) == 2 assert len(input_types) == 2 return _op.cast(inputs[0], input_types[1]) - return _impl - - -def _gather(): - def _impl(inputs, input_types): + def gather(self, inputs, input_types): data = inputs[0] axis = inputs[1] indices = inputs[2] return _op.gather(data, axis, indices) - return _impl - - -def _add(prelude): - # add_ is overloaded for tensor add and list concat - def _impl(inputs, input_types): + def add(self, inputs, input_types): + # add_ is overloaded for tensor add and list concat if input_types[0] == "ListType": - return prelude.concat(inputs[0], inputs[1]) - return _elemwise("add")(inputs, input_types) - - return _impl + return self.prelude.concat(inputs[0], inputs[1]) + return self.make_elemwise("add")(inputs, input_types) - -def _tensor_array_stack(prelude): - def _impl(inputs, input_types): + def tensor_array_stack(self, inputs, input_types): dim = inputs[1] assert dim == 0, "stacking on a dynamic tensor list only supported on a first axis" - tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + tensor_array, shape = _convert_to_tensor_array(inputs[0], self.prelude) stacked_shape = (Any(),) + shape - stack = prelude.get_global_var_static("tensor_array_stack", "float32", shape) + stack = self.prelude.get_global_var_static("tensor_array_stack", "float32", shape) stacked = stack(tensor_array) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", stacked_shape) static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) + get_tensor = self.prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) return get_tensor(stacked) - return _impl - - -def _stack(prelude): - def _impl(inputs, input_types): + def stack(self, inputs, input_types): if isinstance(inputs[0], list): # a static python list of tensors dim = inputs[1] @@ -2074,17 +1702,13 @@ def _impl(inputs, input_types): else: # List ADT case assert isinstance(inputs[0], _expr.Expr) - ty = _infer_type_with_prelude(inputs[0], 
prelude) - list_ty = prelude.mod.get_global_type_var("List") + ty = _infer_type_with_prelude(inputs[0], self.prelude) + list_ty = self.prelude.mod.get_global_type_var("List") msg = "The input list is expected to be List ADT" assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg - return _tensor_array_stack(prelude)(inputs, input_types) - - return _impl - + return self.tensor_array_stack(inputs, input_types) -def _rsub(): - def _impl(inputs, input_types): + def rsub(self, inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) # TODO (t-vi): should this also be part of the type promotion? @@ -2093,21 +1717,13 @@ def _impl(inputs, input_types): # note: rsub means data0 and data1 swap places return get_relay_op("subtract")(data1, alpha * data0) - return _impl - - -def _embedding(): - def _impl(inputs, input_types): + def embedding(self, inputs, input_types): weight = inputs[0] indices = inputs[1] return _op.take(weight, indices.astype("int32"), axis=0) - return _impl - - -def _one_hot(): - def _impl(inputs, input_types): + def one_hot(self, inputs, input_types): indices = inputs[0].astype("int32") num_classes = inputs[1] if num_classes == -1: @@ -2120,28 +1736,16 @@ def _impl(inputs, input_types): return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype) - return _impl - - -def _index(): - def _impl(inputs, input_types): + def index(self, inputs, input_types): data = inputs[0] indices = inputs[1] return _op.adv_index([data] + indices) - return _impl - - -def _meshgrid(): - def _impl(inputs, input_types): + def meshgrid(self, inputs, input_types): data = inputs[0] return _op.meshgrid(data, indexing="ij") - return _impl - - -def _nms(prelude): - def _impl(inputs, input_types): + def nms(self, inputs, input_types): boxes = inputs[0] scores = inputs[1] iou_threshold = inputs[2] @@ -2187,11 +1791,7 @@ def _impl(inputs, input_types): # in torchvision, indices from nms are int64 return _op.cast(ret, "int64") - return _impl - - -def _logsumexp(): - def _impl(inputs, input_types): + def logsumexp(self, inputs, input_types): data = _pytorch_promote_types(inputs[:1], input_types[:1]) dim_list = inputs[1] keepdim = inputs[2] if len(inputs) > 2 else False @@ -2199,11 +1799,7 @@ def _impl(inputs, input_types): assert isinstance(dim_list, list), "dim is expected to be a list" return _op.logsumexp(data[0], axis=dim_list, keepdims=keepdim) - return _impl - - -def _roi_align(prelude): - def _impl(inputs, input_types): + def roi_align(self, inputs, input_types): data = inputs[0] boxes = inputs[1] @@ -2217,16 +1813,12 @@ def _impl(inputs, input_types): return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) - return _impl - - -def _unbind(): - def _impl(inputs, input_types): + def unbind(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) ishapes = _infer_shape(data) if dim >= len(ishapes): - msg = "Please check input dim, it shouldn't" "be greater than or equal to rank." + msg = "Please check input dim, it shouldn't be greater than or equal to rank." 
raise AttributeError(msg) selections = ishapes[dim] @@ -2239,13 +1831,9 @@ def _impl(inputs, input_types): ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) return ret - return _impl - - -def _shape_as_tensor(prelude): - def _impl(inputs, input_types): + def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], prelude.mod) + input_shape = _infer_shape(inputs[0], self.prelude.mod) for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True @@ -2258,45 +1846,30 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _logical_and(): - def _impl(inputs, input_types): + def logical_and(self, inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_and(lhs, rhs) - return _impl - - -def _nonzero(is_numpy_style): - def _impl(inputs, input_types): + def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) - if is_numpy_style or (len(inputs) > 1 and inputs[1]): - return _unbind()([ret, 1], None) - + return self.unbind([ret, 1], None) return ret - return _impl - + def nonzero_numpy(self, inputs, input_types): + return self.nonzero(inputs, input_types, is_numpy_style=False) -def _scatter(): - def _impl(inputs, input_types): + def scatter(self, inputs, input_types): data = inputs[0] axis = int(inputs[1]) index = inputs[2] src = inputs[3] return _op.transform.scatter(data, index, src, axis) - return _impl - - -def _scalar_tensor(): - def _impl(inputs, input_types): + def scalar_tensor(self, inputs, input_types): data = inputs[0] cast_map = { 6: "float32", @@ -2309,11 +1882,7 @@ def _impl(inputs, input_types): data = data.data.asnumpy().tolist() return _expr.const(data, cast_map[type_key]) - return _impl - - -def _interpolate(): - def _impl(inputs, input_types): + def interpolate(self, inputs, input_types): if isinstance(inputs[1], _expr.Expr): out_size = inputs[1] elif isinstance(inputs[1], list): @@ -2342,26 +1911,14 @@ def _impl(inputs, input_types): return _op.image.resize(data, out_size, "NCHW", method, coord_trans) - return _impl - - -def _numel(): - def _impl(inputs, input_types): + def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) - return _impl - - -def _empty(): - def _impl(inputs, input_types): + def empty(self, inputs, input_types): shape = inputs[0] return _op.zeros(shape, _convert_dtype_value(inputs[1])) - return _impl - - -def _bincount(): - def _impl(inputs, input_types): + def bincount(self, inputs, input_types): data = inputs[0] weights = inputs[1] maximum = _op.max(data) @@ -2377,18 +1934,427 @@ def _impl(inputs, input_types): counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) return _op.scatter_add(counts, data, updates, axis=0) - return _impl - - -def _scatter_add(): - def _impl(inputs, input_types): + def scatter_add(self, inputs, input_types): data = inputs[0] axis = inputs[1] index = inputs[2] src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) - return _impl + # Operator mappings + def create_convert_map(self): + self.convert_map = { + "aten::pixel_shuffle": self.pixel_shuffle, + "aten::device": self.none, + "prim::device": self.none, + "aten::sub": self.make_elemwise("subtract"), + "aten::sub_": self.make_elemwise("subtract"), + "aten::max": self.max, + "aten::min": self.min, + "aten::mul": self.make_elemwise("multiply"), + "aten::mul_": self.make_elemwise("multiply"), + "aten::pow": self.make_elemwise("power"), + 
"aten::arange": self.arange, + "aten::meshgrid": self.meshgrid, + "aten::div": self.make_elemwise("divide"), + "aten::div_": self.make_elemwise("divide"), + "aten::floor_divide": self.make_elemwise("floor_divide"), + "aten::true_divide": self.make_elemwise("divide"), + "aten::addcdiv": self.addcdiv, + "aten::addcmul": self.addcmul, + "aten::ones": self.ones, + "aten::ones_like": self.ones_like, + "aten::zeros": self.zeros, + "aten::zeros_like": self.zeros_like, + "aten::full": self.full, + "aten::full_like": self.full_like, + "aten::linspace": self.linspace, + "aten::reciprocal": self.reciprocal, + "aten::repeat": self.repeat, + "aten::repeat_interleave": self.repeat_interleave, + "aten::to": self.to, + "aten::squeeze": self.squeeze, + "aten::unsqueeze": self.unsqueeze, + "aten::cat": self.concatenate, + "aten::slice": self.slice, + "aten::split": self.split, + "aten::split_with_sizes": self.split_with_sizes, + "aten::select": self.select, + "aten::take": self.take, + "aten::where": self.where, + "aten::topk": self.topk, + "aten::relu": self.relu, + "aten::relu_": self.relu, + "aten::prelu": self.prelu, + "aten::leaky_relu": self.leaky_relu, + "aten::leaky_relu_": self.leaky_relu, + "aten::elu": self.elu, + "aten::elu_": self.elu, + "aten::celu": self.celu, + "aten::gelu": self.gelu, + "aten::selu": self.selu, + "aten::log_sigmoid": self.log_sigmoid, + "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d, + "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d, + "aten::max_pool2d": self.maxpool_2d, + "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices, + "aten::max_pool1d": self.maxpool_1d, + "aten::max_pool3d": self.maxpool_3d, + "aten::hardtanh": self.hardtanh, + "aten::hardtanh_": self.hardtanh, + "aten::_convolution": self.convolution, + "aten::softmax": self.softmax, + "aten::threshold": self.threshold, + "aten::threshold_": self.threshold, + "aten::contiguous": self.contiguous, + "aten::batch_norm": self.batch_norm, + "aten::instance_norm": self.instance_norm, + "aten::layer_norm": self.layer_norm, + "aten::group_norm": self.group_norm, + "aten::transpose": self.transpose, + "aten::transpose_": self.transpose, + "aten::t": self.transpose, + "aten::flatten": self.flatten, + "aten::addmm": self.addmm, + "aten::size": self.size, + "aten::view": self.view, + "aten::reshape": self.reshape, + "aten::clone": self.clone, + "aten::log_softmax": self.log_softmax, + "aten::sigmoid": self.sigmoid, + "aten::softplus": self.softplus, + "aten::avg_pool2d": self.avg_pool2d, + "aten::avg_pool3d": self.avg_pool3d, + "aten::dropout": self.dropout, + "aten::dropout_": self.dropout, + "aten::feature_dropout": self.dropout, + "aten::alpha_dropout": self.dropout, + "aten::mean": self.mean, + "aten::chunk": self.chunk, + "aten::matmul": self.matmul, + "aten::bmm": self.matmul, + "aten::expand": self.expand, + "aten::Int": self.int, + "prim::NumToTensor": self.numtotensor, + "prim::ImplicitTensorToNum": self.tensortonum, + "aten::ScalarImplicit": self.tensortonum, + "aten::constant_pad_nd": self.make_pad("constant"), + "aten::reflection_pad1d": self.make_pad("reflect"), + "aten::reflection_pad2d": self.make_pad("reflect"), + "aten::replication_pad1d": self.make_pad("edge"), + "aten::replication_pad2d": self.make_pad("edge"), + "aten::replication_pad3d": self.make_pad("edge"), + "aten::permute": self.transpose, + "aten::sum": self.make_reduce("sum"), + "aten::prod": self.make_reduce("prod"), + "aten::argmin": self.make_reduce("argmin"), + "aten::argmax": self.make_reduce("argmax"), + 
"aten::norm": self.norm, + "aten::frobenius_norm": self.frobenius_norm, + "aten::std": self.std, + "aten::var": self.variance, + "aten::abs": self.make_unary("abs"), + "aten::neg": self.make_unary("negative"), + "aten::cos": self.make_unary("cos"), + "aten::cosh": self.make_unary("cosh"), + "aten::sin": self.make_unary("sin"), + "aten::sinh": self.make_unary("sinh"), + "aten::tan": self.make_unary("tan"), + "aten::tanh": self.make_unary("tanh"), + "aten::acos": self.make_unary("acos"), + "aten::asin": self.make_unary("asin"), + "aten::atan": self.make_unary("atan"), + "aten::log": self.make_unary("log"), + "aten::log2": self.make_unary("log2"), + "aten::log10": self.make_unary("log10"), + "aten::log1p": self.log1p, + "aten::exp": self.make_unary("exp"), + "aten::erf": self.make_unary("erf"), + "aten::trunc": self.make_unary("trunc"), + "aten::sign": self.make_unary("sign"), + "aten::sqrt": self.make_unary("sqrt"), + "aten::rsqrt": self.make_unary("rsqrt"), + "aten::ceil": self.make_unary("ceil"), + "aten::floor": self.make_unary("floor"), + "aten::round": self.make_unary("round"), + "aten::isfinite": self.make_unary("isfinite"), + "aten::isinf": self.make_unary("isinf"), + "aten::isnan": self.make_unary("isnan"), + "aten::clamp": self.clamp, + "aten::clamp_": self.clamp, + "aten::detach": self.identity, + "aten::upsample_bilinear2d": self.make_upsample("bilinear"), + "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), + "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"), + "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), + "aten::expand_as": self.expand_as, + "aten::lt": self.make_elemwise("less"), + "aten::gt": self.make_elemwise("greater"), + "aten::le": self.make_elemwise("less_equal"), + "aten::ge": self.make_elemwise("greater_equal"), + "aten::ne": self.make_elemwise("not_equal"), + "aten::eq": self.make_elemwise("equal"), + "aten::logical_not": self.logical_not, + "aten::logical_xor": self.logical_xor, + "aten::bitwise_not": self.bitwise_not, + "aten::bitwise_xor": self.bitwise_xor, + "aten::Bool": self.Bool, + "aten::Float": self.Float, + "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d, + "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d, + "aten::rsub": self.rsub, + "aten::embedding": self.embedding, + "aten::one_hot": self.one_hot, + "aten::mm": self.matmul, + "aten::add": self.add, + "aten::add_": self.add, + "aten::stack": self.stack, + "aten::__getitem__": self.list_getitem, + "aten::len": self.list_len, + "aten::type_as": self.type_as, + "aten::gather": self.gather, + "aten::index_select": self.select, + "aten::index": self.index, + "torchvision::nms": self.nms, + "aten::logsumexp": self.logsumexp, + "torchvision::roi_align": self.roi_align, + "aten::unbind": self.unbind, + "aten::__and__": self.logical_and, + "aten::_shape_as_tensor": self.shape_as_tensor, + "aten::nonzero": self.nonzero, + "aten::nonzero_numpy": self.nonzero_numpy, + "aten::scatter": self.scatter, + "aten::scalar_tensor": self.scalar_tensor, + "aten::__interpolate": self.interpolate, + "aten::IntImplicit": self.identity, + "aten::tensor": self.identity, # used for example in tensor(1.0) + "aten::numel": self.numel, + "aten::empty": self.empty, + "aten::bincount": self.bincount, + "aten::scatter_add": self.scatter_add, + "aten::__not__": self.logical_not, + } + + def update_convert_map(self, custom_map): + self.convert_map.update(custom_map) + + def report_missing_conversion(self, op_names): + """ Check if all ops in an input graph are supported by 
TVM """ + known_ops = [ + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::TupleConstruct", + "prim::TupleUnpack", + "prim::RaiseException", + "prim::If", + "prim::Loop", + ] + known_ops += list(self.convert_map.keys()) + known_ops += list(qnn_torch.convert_map.keys()) + + missing = [op_name for op_name in op_names if op_name not in known_ops] + + if missing: + msg = "The following operators are not implemented: {}".format(missing) + raise NotImplementedError(msg) + + def convert_block(self, block, outputs): + """ Translate Torch "Block", used for prim::If and prim::Loop """ + ops = _get_operator_nodes(block.nodes()) + ret_names = _get_input_names(block.returnNode()) + return self.convert_operators(ops, outputs, ret_names) + + def convert_if(self, if_node, outputs): + """ Translate Torch prim::If to Relay If """ + cond = outputs[if_node.inputsAt(0).debugName()] + blocks = list(if_node.blocks()) + true_branch = self.convert_block(blocks[0], outputs) + false_branch = self.convert_block(blocks[1], outputs) + assert len(true_branch) == 1 and len(false_branch) == 1 + return _expr.If(cond, true_branch[0], false_branch[0]) + + def convert_loop(self, loop_node, outputs): + """ Translate Torch prim::Loop to Relay while_loop """ + + def get_input(index): + ivalue = loop_node.inputsAt(index) + inode = ivalue.node() + if inode.kind() == "prim::Constant": + return _expr.const(_get_constant(inode)) + var_name = ivalue.debugName() + assert var_name in outputs + return _wrap_const(outputs[var_name]) + + # Refer to the spec for prim::Loop below + # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops + # The first input: %max_trip_count + # The second input: %initial_condition + # The rest of input: loop variables + max_loop_count = get_input(0) + init_cond = get_input(1) + num_loop_var = len(list(loop_node.inputs())) - 2 + init_vals = [get_input(i + 2) for i in range(num_loop_var)] + + # while loop has always max_loop_count being int64 max + # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again + is_while_loop = ( + isinstance(max_loop_count, _expr.Constant) + and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize + ) + + if is_while_loop: + loop_iter_dtype = "bool" + # while loop with non input dependent condition such as while i < 10: + # init_cond is int, need to cast to bool to type check + if isinstance(init_cond, _expr.Constant): + init_cond = _op.cast(init_cond, "bool") + init_loop_iter_val = init_cond + else: + loop_iter_dtype = "int32" + # always count from 0 + init_loop_iter_val = _expr.const(0, dtype="int32") + + body_block = list(loop_node.blocks())[0] + block_input_names = _get_input_names(body_block) + num_block_inputs = len(block_input_names) + name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals)) + outputs.update(name_val_pairs) + + def get_var(name, val): + if val: + checked_type = _infer_type_with_prelude(val, self.prelude) + if hasattr(checked_type, "shape"): + shape = get_const_tuple(checked_type.shape) + actual_shape = [] + for dim in shape: + if isinstance(dim, int) and dim == 0: + actual_shape.append(Any()) + else: + actual_shape.append(dim) + return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype) + else: + return _expr.var(name, type_annotation=checked_type) + return _expr.var(name) + + loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) + loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] + + # Add 
non constant free variables to loop variables to prevent code blow up + # Without this, if there are two for loops in a row, which often happens + # if the outer loop is unrolled, the computation corresponding to the first for loop + # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2). + # This issue was found when converting from Stacked LSTM test. Torch does not add the + # outputof the eariler loop into loop variables of the next loop. + # So the variable corresponding to the first loop output appears free in the second + # loop body. + free_vars = [ + var + for var in _get_free_vars_from_block(body_block) + if var in outputs + and not isinstance(outputs[var], (_expr.Constant, int, float, str)) + and outputs[var] + ] + + prev_outputs = {} + for name in free_vars: + prev_output = outputs[name] + new_loop_var = get_var(name, prev_output) + prev_outputs[name] = prev_output + outputs[name] = new_loop_var + loop_vars.append(new_loop_var) + init_vals.append(prev_output) + + def cond(*current_vals): + i = current_vals[0] + + if is_while_loop: + return _op.equal(i, _expr.const(True, "bool")) + + return _op.less(i, max_loop_count) + + def body(*current_vals): + # Update loop variables using the prev iteration outputs + assert len(current_vals) == num_block_inputs + len(free_vars) + + for (i, val) in enumerate(current_vals): + if i < num_block_inputs: + outputs[block_input_names[i]] = val + else: + outputs[free_vars[i - num_block_inputs]] = val + + block_outputs = self.convert_block(body_block, outputs) + block_outputs += [outputs[name] for name in free_vars] + + if not is_while_loop: + # iter var increment implicit in torch, so do it manually + # for while loop, block_outputs[0] is already a boolean, + # the result of termination check + incr = _expr.const(1, dtype="int32") + block_outputs[0] = current_vals[0] + incr + + return block_outputs + + loop = while_loop(cond, [loop_iter_var] + loop_vars, body) + loop_val = loop(init_loop_iter_val, *init_vals) + + # restore original output values for free vars + outputs.update(prev_outputs) + + # The first element is a loop counter or boolean condition, ignore it + return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)] + + def convert_operators(self, operators, outputs, ret_names): + """ Convert each Torch IR operators to Relay equivalent """ + for node_name, op_node in operators: + operator = op_node.kind() + inputs = _get_op_inputs(op_node, outputs) + + if operator == "prim::Constant": + outputs[node_name] = _get_constant(op_node) + elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): + outputs[node_name] = _convert_to_list_adt(inputs, self.prelude) + elif operator == "prim::ListConstruct": + # This assumes that no more elements will be appended to this list + # In this case, we keep the Python list + outputs[node_name] = inputs + elif operator == "prim::TupleConstruct": + outputs[node_name] = _expr.Tuple(inputs) + elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: + assert len(inputs) == 1 + if isinstance(inputs[0], (list, _expr.TupleWrapper)): + unpacked = inputs[0] + else: + unpacked = _unpack_tuple(inputs[0]) + outputs.update(zip(_get_output_names(op_node), unpacked)) + elif operator == "prim::prim::RaiseException": + logging.warning("raising exceptions is ignored") + outputs[node_name] = None + elif operator == "prim::If": + if_out = self.convert_if(op_node, outputs) + outputs[node_name] = if_out + elif operator == "prim::Loop": + loop_out = 
+                unpacked_names = _get_output_names(op_node)
+                assert len(loop_out) == len(unpacked_names)
+                outputs.update(zip(unpacked_names, loop_out))
+            else:
+                relay_op = self.convert_map[operator]
+                relay_out = relay_op(
+                    inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
+                )
+
+                if isinstance(relay_out, tuple):
+                    # This is for torch operators that return multiple outputs
+                    # See _adaptive_max_2d above for example
+                    out_names = _get_output_names(op_node)
+                    outputs.update(zip(out_names, relay_out))
+                else:
+                    assert op_node.outputsSize() == 1
+                    outputs[node_name] = relay_out
+
+        return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
 
 
 def _pytorch_result_type(dtypes, non_tensor_inputs):
@@ -2544,202 +2510,6 @@ def _wrap_const(c):
     return c
 
 
-# Operator mappings
-def _get_convert_map(prelude, default_dtype):
-    convert_map = {
-        "aten::pixel_shuffle": _pixel_shuffle(prelude),
-        "aten::device": _none(),
-        "prim::device": _none(),
-        "aten::sub": _elemwise("subtract"),
-        "aten::sub_": _elemwise("subtract"),
-        "aten::max": _max(),
-        "aten::min": _min(),
-        "aten::mul": _elemwise("multiply"),
-        "aten::mul_": _elemwise("multiply"),
-        "aten::pow": _elemwise("power"),
-        "aten::arange": _arange(),
-        "aten::meshgrid": _meshgrid(),
-        "aten::div": _elemwise("divide"),
-        "aten::div_": _elemwise("divide"),
-        "aten::floor_divide": _elemwise("floor_divide"),
-        "aten::true_divide": _elemwise("divide"),
-        "aten::addcdiv": _addcdiv(),
-        "aten::addcmul": _addcmul(),
-        "aten::ones": _ones(default_dtype),
-        "aten::ones_like": _ones_like(default_dtype),
-        "aten::zeros": _zeros(default_dtype),
-        "aten::zeros_like": _zeros_like(default_dtype),
-        "aten::full": _full(default_dtype),
-        "aten::full_like": _full_like(default_dtype),
-        "aten::linspace": _linspace(),
-        "aten::reciprocal": _reciprocal(),
-        "aten::repeat": _repeat(),
-        "aten::repeat_interleave": _repeat_interleave(),
-        "aten::to": _to(),
-        "aten::squeeze": _squeeze(),
-        "aten::unsqueeze": _unsqueeze(),
-        "aten::cat": _concatenate(prelude),
-        "aten::slice": _slice(),
-        "aten::split": _split(),
-        "aten::split_with_sizes": _split_with_sizes(),
-        "aten::select": _select(),
-        "aten::take": _take(),
-        "aten::where": _where(),
-        "aten::topk": _topk(),
-        "aten::relu": _relu(prelude),
-        "aten::relu_": _relu(prelude),
-        "aten::prelu": _prelu(),
-        "aten::leaky_relu": _leaky_relu(),
-        "aten::leaky_relu_": _leaky_relu(),
-        "aten::elu": _elu(),
-        "aten::elu_": _elu(),
-        "aten::celu": _celu(),
-        "aten::gelu": _gelu(),
-        "aten::selu": _selu(),
-        "aten::log_sigmoid": _log_sigmoid(),
-        "aten::adaptive_avg_pool2d": _adaptive_avg_pool_2d(prelude),
-        "aten::adaptive_max_pool2d": _adaptive_max_pool_2d(),
-        "aten::max_pool2d": _maxpool_2d(),
-        "aten::max_pool2d_with_indices": _maxpool_2d_with_indices(),
-        "aten::max_pool1d": _maxpool_1d(),
-        "aten::max_pool3d": _maxpool_3d(),
-        "aten::hardtanh": _hardtanh(),
-        "aten::hardtanh_": _hardtanh(),
-        "aten::_convolution": _convolution(),
-        "aten::softmax": _softmax(),
-        "aten::threshold": _threshold(),
-        "aten::threshold_": _threshold(),
-        "aten::contiguous": _contiguous(),
-        "aten::batch_norm": _batch_norm(),
-        "aten::instance_norm": _instance_norm(),
-        "aten::layer_norm": _layer_norm(),
-        "aten::group_norm": _group_norm(),
-        "aten::transpose": _transpose(prelude),
-        "aten::transpose_": _transpose(prelude),
-        "aten::t": _transpose(prelude),
-        "aten::flatten": _flatten(),
-        "aten::addmm": _addmm(),
-        "aten::size": _size(prelude),
-        "aten::view": _view(),
-        "aten::reshape": _reshape(),
-        "aten::clone": _clone(),
-        "aten::log_softmax": _log_softmax(),
-        "aten::sigmoid": _sigmoid(),
-        "aten::softplus": _softplus(),
-        "aten::avg_pool2d": _avg_pool2d(prelude),
-        "aten::avg_pool3d": _avg_pool3d(),
-        "aten::dropout": _dropout(),
-        "aten::dropout_": _dropout(),
-        "aten::feature_dropout": _dropout(),
-        "aten::alpha_dropout": _dropout(),
-        "aten::mean": _mean(prelude),
-        "aten::chunk": _chunk(prelude),
-        "aten::matmul": _matmul(prelude),
-        "aten::bmm": _matmul(prelude),
-        "aten::expand": _expand(),
-        "aten::Int": _int(),
-        "prim::NumToTensor": _numtotensor(),
-        "prim::ImplicitTensorToNum": _tensortonum(),
-        "aten::ScalarImplicit": _tensortonum(),
-        "aten::constant_pad_nd": _pad("constant"),
-        "aten::reflection_pad1d": _pad("reflect"),
-        "aten::reflection_pad2d": _pad("reflect"),
-        "aten::replication_pad1d": _pad("edge"),
-        "aten::replication_pad2d": _pad("edge"),
-        "aten::replication_pad3d": _pad("edge"),
-        "aten::permute": _transpose(prelude),
-        "aten::sum": _reduce("sum"),
-        "aten::prod": _reduce("prod"),
-        "aten::argmin": _reduce("argmin"),
-        "aten::argmax": _reduce("argmax"),
-        "aten::norm": _norm(),
-        "aten::frobenius_norm": _frobenius_norm(),
-        "aten::std": _std(),
-        "aten::var": _variance(),
-        "aten::abs": _unary("abs"),
-        "aten::neg": _unary("negative"),
-        "aten::cos": _unary("cos"),
-        "aten::cosh": _unary("cosh"),
-        "aten::sin": _unary("sin"),
-        "aten::sinh": _unary("sinh"),
-        "aten::tan": _unary("tan"),
-        "aten::tanh": _unary("tanh"),
-        "aten::acos": _unary("acos"),
-        "aten::asin": _unary("asin"),
-        "aten::atan": _unary("atan"),
-        "aten::log": _unary("log"),
-        "aten::log2": _unary("log2"),
-        "aten::log10": _unary("log10"),
-        "aten::log1p": _log1p(),
-        "aten::exp": _unary("exp"),
-        "aten::erf": _unary("erf"),
-        "aten::trunc": _unary("trunc"),
-        "aten::sign": _unary("sign"),
-        "aten::sqrt": _unary("sqrt"),
-        "aten::rsqrt": _unary("rsqrt"),
-        "aten::ceil": _unary("ceil"),
-        "aten::floor": _unary("floor"),
-        "aten::round": _unary("round"),
-        "aten::isfinite": _unary("isfinite"),
-        "aten::isinf": _unary("isinf"),
-        "aten::isnan": _unary("isnan"),
-        "aten::clamp": _clamp(),
-        "aten::clamp_": _clamp(),
-        "aten::detach": _identity(),
-        "aten::upsample_bilinear2d": _upsample("bilinear", prelude),
-        "aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude),
-        "aten::upsample_trilinear3d": _upsample3d("trilinear"),
-        "aten::upsample_nearest3d": _upsample3d("nearest_neighbor"),
-        "aten::expand_as": _expand_as(),
-        "aten::lt": _elemwise("less"),
-        "aten::gt": _elemwise("greater"),
-        "aten::le": _elemwise("less_equal"),
-        "aten::ge": _elemwise("greater_equal"),
-        "aten::ne": _elemwise("not_equal"),
-        "aten::eq": _elemwise("equal"),
-        "aten::logical_not": _logical_not(),
-        "aten::logical_xor": _logical_xor(),
-        "aten::bitwise_not": _bitwise_not(),
-        "aten::bitwise_xor": _bitwise_xor(),
-        "aten::Bool": _Bool(),
-        "aten::Float": _Float(),
-        "aten::adaptive_avg_pool3d": _adaptive_avg_pool_3d(),
-        "aten::adaptive_max_pool3d": _adaptive_max_pool_3d(),
-        "aten::rsub": _rsub(),
-        "aten::embedding": _embedding(),
-        "aten::one_hot": _one_hot(),
-        "aten::mm": _matmul(prelude),
-        "aten::add": _add(prelude),
-        "aten::add_": _add(prelude),
-        "aten::stack": _stack(prelude),
-        "aten::__getitem__": _list_getitem(prelude),
-        "aten::len": _list_len(prelude),
-        "aten::type_as": _type_as(),
-        "aten::gather": _gather(),
-        "aten::index_select": _select(),
-        "aten::index": _index(),
-        "torchvision::nms": _nms(prelude),
-        "aten::logsumexp": _logsumexp(),
-        "torchvision::roi_align": _roi_align(prelude),
-        "aten::unbind": _unbind(),
-        "aten::__and__": _logical_and(),
-        "aten::_shape_as_tensor": _shape_as_tensor(prelude),
-        "aten::nonzero": _nonzero(False),
-        "aten::nonzero_numpy": _nonzero(True),
-        "aten::scatter": _scatter(),
-        "aten::scalar_tensor": _scalar_tensor(),
-        "aten::__interpolate": _interpolate(),
-        "aten::IntImplicit": _identity(),
-        "aten::tensor": _identity(),  # used for example in tensor(1.0)
-        "aten::numel": _numel(),
-        "aten::empty": _empty(),
-        "aten::bincount": _bincount(),
-        "aten::scatter_add": _scatter_add(),
-        "aten::__not__": _logical_not(),
-    }
-    return convert_map
-
-
 def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     # pylint: disable=c-extension-no-member
@@ -2793,29 +2563,6 @@ def _get_users(node):
     return [use.user for use in _get_uses(node)]
 
 
-def _report_missing_conversion(op_names, convert_map):
-    """ Check if all ops in an input graph are supported by TVM """
-    known_ops = [
-        "prim::Constant",
-        "prim::GetAttr",
-        "prim::ListConstruct",
-        "prim::ListUnpack",
-        "prim::TupleConstruct",
-        "prim::TupleUnpack",
-        "prim::RaiseException",
-        "prim::If",
-        "prim::Loop",
-    ]
-    known_ops += list(convert_map.keys())
-    known_ops += list(qnn_torch.convert_map.keys())
-
-    missing = [op_name for op_name in op_names if op_name not in known_ops]
-
-    if missing:
-        msg = "The following operators are not implemented: {}".format(missing)
-        raise NotImplementedError(msg)
-
-
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
@@ -3117,211 +2864,6 @@ def convert_params(graph, state_dict):
     return params, param_tensors, packed_param_map
 
 
-def convert_block(block, outputs, convert_map, prelude, default_dtype="float32"):
-    """ Translate Torch "Block", used for prim::If and prim::Loop """
-    ops = _get_operator_nodes(block.nodes())
-    ret_names = _get_input_names(block.returnNode())
-    return convert_operators(
-        ops, outputs, ret_names, convert_map, prelude, default_dtype=default_dtype
-    )
-
-
-def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"):
-    """ Translate Torch prim::If to Relay If """
-    cond = outputs[if_node.inputsAt(0).debugName()]
-    blocks = list(if_node.blocks())
-    true_branch = convert_block(
-        blocks[0], outputs, convert_map, prelude, default_dtype=default_dtype
-    )
-    false_branch = convert_block(
-        blocks[1], outputs, convert_map, prelude, default_dtype=default_dtype
-    )
-    assert len(true_branch) == 1 and len(false_branch) == 1
-    return _expr.If(cond, true_branch[0], false_branch[0])
-
-
-def convert_loop(loop_node, outputs, convert_map, prelude):
-    """ Translate Torch prim::Loop to Relay while_loop """
-
-    def get_input(index):
-        ivalue = loop_node.inputsAt(index)
-        inode = ivalue.node()
-        if inode.kind() == "prim::Constant":
-            return _expr.const(_get_constant(inode))
-        var_name = ivalue.debugName()
-        assert var_name in outputs
-        return _wrap_const(outputs[var_name])
-
-    # Refer to the spec for prim::Loop below
-    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
-    # The first input: %max_trip_count
-    # The second input: %initial_condition
-    # The rest of input: loop variables
-    max_loop_count = get_input(0)
-    init_cond = get_input(1)
-    num_loop_var = len(list(loop_node.inputs())) - 2
-    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
-
-    # while loop has always max_loop_count being int64 max
-    # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again
-    is_while_loop = (
-        isinstance(max_loop_count, _expr.Constant)
-        and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize
-    )
-
-    if is_while_loop:
-        loop_iter_dtype = "bool"
-        # while loop with non input dependent condition such as while i < 10:
-        # init_cond is int, need to cast to bool to type check
-        if isinstance(init_cond, _expr.Constant):
-            init_cond = _op.cast(init_cond, "bool")
-        init_loop_iter_val = init_cond
-    else:
-        loop_iter_dtype = "int32"
-        # always count from 0
-        init_loop_iter_val = _expr.const(0, dtype="int32")
-
-    body_block = list(loop_node.blocks())[0]
-    block_input_names = _get_input_names(body_block)
-    num_block_inputs = len(block_input_names)
-    name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals))
-    outputs.update(name_val_pairs)
-
-    def get_var(name, val):
-        if val:
-            checked_type = _infer_type_with_prelude(val, prelude)
-            if hasattr(checked_type, "shape"):
-                shape = get_const_tuple(checked_type.shape)
-                actual_shape = []
-                for dim in shape:
-                    if isinstance(dim, int) and dim == 0:
-                        actual_shape.append(Any())
-                    else:
-                        actual_shape.append(dim)
-                return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
-            else:
-                return _expr.var(name, type_annotation=checked_type)
-        return _expr.var(name)
-
-    loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype)
-    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
-
-    # Add non constant free variables to loop variables to prevent code blow up
-    # Without this, if there are two for loops in a row, which often happens
-    # if the outer loop is unrolled, the computation corresponding to the first for loop
-    # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2).
-    # This issue was found when converting from Stacked LSTM test. Torch does not add the output
-    # of the eariler loop into loop variables of the next loop.
-    # So the variable corresponding to the first loop output appears free in the second loop body.
-    free_vars = [
-        var
-        for var in _get_free_vars_from_block(body_block)
-        if var in outputs
-        and not isinstance(outputs[var], (_expr.Constant, int, float, str))
-        and outputs[var]
-    ]
-
-    prev_outputs = {}
-    for name in free_vars:
-        prev_output = outputs[name]
-        new_loop_var = get_var(name, prev_output)
-        prev_outputs[name] = prev_output
-        outputs[name] = new_loop_var
-        loop_vars.append(new_loop_var)
-        init_vals.append(prev_output)
-
-    def cond(*current_vals):
-        i = current_vals[0]
-
-        if is_while_loop:
-            return _op.equal(i, _expr.const(True, "bool"))
-
-        return _op.less(i, max_loop_count)
-
-    def body(*current_vals):
-        # Update loop variables using the prev iteration outputs
-        assert len(current_vals) == num_block_inputs + len(free_vars)
-
-        for (i, val) in enumerate(current_vals):
-            if i < num_block_inputs:
-                outputs[block_input_names[i]] = val
-            else:
-                outputs[free_vars[i - num_block_inputs]] = val
-
-        block_outputs = convert_block(body_block, outputs, convert_map, prelude)
-        block_outputs += [outputs[name] for name in free_vars]
-
-        if not is_while_loop:
-            # iter var increment implicit in torch, so do it manually
-            # for while loop, block_outputs[0] is already a boolean,
-            # the result of termination check
-            incr = _expr.const(1, dtype="int32")
-            block_outputs[0] = current_vals[0] + incr
-
-        return block_outputs
-
-    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
-    loop_val = loop(init_loop_iter_val, *init_vals)
-
-    # restore original output values for free vars
-    outputs.update(prev_outputs)
-
-    # The first element is a loop counter or boolean condition, ignore it
-    return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)]
-
-
-def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
-    """ Convert each Torch IR operators to Relay equivalent """
-    for node_name, op_node in operators:
-        operator = op_node.kind()
-        inputs = _get_op_inputs(op_node, outputs)
-
-        if operator == "prim::Constant":
-            outputs[node_name] = _get_constant(op_node)
-        elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
-            outputs[node_name] = _convert_to_list_adt(inputs, prelude)
-        elif operator == "prim::ListConstruct":
-            # This assumes that no more elements will be appended to this list
-            # In this case, we keep the Python list
-            outputs[node_name] = inputs
-        elif operator == "prim::TupleConstruct":
-            outputs[node_name] = _expr.Tuple(inputs)
-        elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
-            assert len(inputs) == 1
-            if isinstance(inputs[0], (list, _expr.TupleWrapper)):
-                unpacked = inputs[0]
-            else:
-                unpacked = _unpack_tuple(inputs[0])
-            outputs.update(zip(_get_output_names(op_node), unpacked))
-        elif operator == "prim::prim::RaiseException":
-            logging.warning("raising exceptions is ignored")
-            outputs[node_name] = None
-        elif operator == "prim::If":
-            if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
-            outputs[node_name] = if_out
-        elif operator == "prim::Loop":
-            loop_out = convert_loop(op_node, outputs, convert_map, prelude)
-            unpacked_names = _get_output_names(op_node)
-            assert len(loop_out) == len(unpacked_names)
-            outputs.update(zip(unpacked_names, loop_out))
-        else:
-            relay_op = convert_map[operator]
-            relay_out = relay_op(
-                inputs, _get_input_types(op_node, outputs, default_dtype=default_dtype)
-            )
-
-            if isinstance(relay_out, tuple):
-                # This is for torch operators that return multiple outputs
-                # See _adaptive_max_2d above for example
-                out_names = _get_output_names(op_node)
-                outputs.update(zip(out_names, relay_out))
-            else:
-                assert op_node.outputsSize() == 1
-                outputs[node_name] = relay_out
-
-    return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
-
-
 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
     nodes = list(graph.nodes())
@@ -3370,16 +2912,16 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
     mod = tvm.IRModule()
     prelude = Prelude(mod)
 
-    convert_map = _get_convert_map(prelude, default_dtype)
+    converter = PyTorchOpConverter(prelude, default_dtype)
 
     graph = script_module.graph.copy()
     _run_jit_passes(graph)
 
     if custom_convert_map:
-        convert_map.update(custom_convert_map)
+        converter.update_convert_map(custom_convert_map)
 
     op_names = get_all_op_names(graph)
-    _report_missing_conversion(op_names, convert_map)
+    converter.report_missing_conversion(op_names)
 
     is_module = isinstance(script_module, torch.jit.ScriptModule)
     params = script_module.state_dict() if is_module else {}
@@ -3399,16 +2941,9 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
         qnn_torch.add_input_quant_params_to_op_inputs(graph)
         qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
-        convert_map.update(qnn_torch.convert_map)
-
-    ret = convert_operators(
-        _get_operator_nodes(graph.nodes()),
-        outputs,
-        ret_name,
-        convert_map,
-        prelude,
-        default_dtype=default_dtype,
-    )
+        converter.update_convert_map(qnn_torch.convert_map)
+
+    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
 
     mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
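
The external from_pytorch interface is unchanged by this refactor: user-supplied converters keep the (inputs, input_types) call signature that convert_operators dispatches on, and custom_convert_map entries are merged into the instance state via converter.update_convert_map(), so they override the built-in entries with the same key. A minimal usage sketch, assuming an illustrative toy model and an overridden "aten::abs" entry (the example is not part of this patch):

    import torch
    import tvm
    from tvm import relay

    # A converter callable matches the dispatch in convert_operators: it
    # receives already-converted Relay inputs plus their input types and
    # returns a Relay expression.
    def my_abs(inputs, input_types):
        return relay.abs(inputs[0])

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.abs(x)

    scripted = torch.jit.trace(Model().eval(), torch.randn(1, 3))
    mod, params = relay.frontend.from_pytorch(
        scripted,
        [("x", (1, 3))],
        # merged through converter.update_convert_map(custom_convert_map)
        custom_convert_map={"aten::abs": my_abs},
    )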