diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index b8273b0324c0..ab106139becf 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -27,7 +27,7 @@ from .. import op as _op from ... import nd as _nd from .common import ExprTable -from .common import infer_shape as _infer_shape +from .common import infer_shape __all__ = ["from_caffe"] @@ -84,8 +84,8 @@ def convert_eltwise(self, op): lhs_expr = self.exp_tab.get_expr(inputs[0]) rhs_expr = self.exp_tab.get_expr(inputs[1]) - lhs_shape = _infer_shape(lhs_expr) - rhs_shape = _infer_shape(rhs_expr) + lhs_shape = infer_shape(lhs_expr) + rhs_shape = infer_shape(rhs_expr) assert lhs_shape == rhs_shape, "input tensors shape should be equal" @@ -163,7 +163,7 @@ def convert_batch_norm(self, op): """Convert BatchNorm layer""" inputs = op.bottom in_expr = self.exp_tab.get_expr(inputs[0]) - n, c, h, w = _infer_shape(in_expr) + n, c, h, w = infer_shape(in_expr) if op.name in self.new_bn: mean, var, eps, gamma, beta = self.new_bn[op.name] @@ -234,7 +234,7 @@ def convert_scale(self, op): np.zeros(gamma.shape, dtype=np.float32), dtype="float32" ) - _, c, _, _ = _infer_shape(in_expr) + _, c, _, _ = infer_shape(in_expr) gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1)) beta_expr = _op.reshape(beta_expr, newshape=(1, c, 1, 1)) out = _op.multiply(in_expr, gamma_expr) @@ -262,7 +262,7 @@ def convert_reshape(self, op): dims = list(reshape_param.shape.dim) in_expr = self.exp_tab.get_expr(input_name) - input_shape = list(_infer_shape(in_expr)) + input_shape = list(infer_shape(in_expr)) start_axis = int(reshape_param.axis) if start_axis < 0: @@ -571,7 +571,7 @@ def convert_crop(self, op): offset = list(getattr(crop_params, "offset", 0)) # expand offset to (offset1, offset2, ...) - in_a_shape = _infer_shape(in_expr_a) + in_a_shape = infer_shape(in_expr_a) num_to_crop = len(in_a_shape) - axis if not offset: offset = [0] * num_to_crop diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 077b942ddf01..8b376885df36 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -474,6 +474,9 @@ def get_name(node): def infer_type(node, mod=None): """A method to infer the type of an intermediate node in the relay graph.""" + if isinstance(node, tvm.relay.Var): + return node.type_annotation + if isinstance(mod, IRModule): mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node) mod = _transform.InferType()(mod) @@ -484,11 +487,16 @@ def infer_type(node, mod=None): if mod is not None: new_mod.update(mod) + new_mod = _transform.RemoveUnusedFunctions()(new_mod) new_mod = _transform.InferType()(new_mod) entry = new_mod["main"] ret = entry if isinstance(node, _function.Function) else entry.body - return ret + return ret.checked_type + + +def infer_type_with_prelude(val, prelude): + return infer_type(val, prelude.mod) def fold_constant(node, mod=None): @@ -502,15 +510,14 @@ def infer_channels(inputs, transpose=False): these attributes. We check the shape of weights provided to get the number. 
""" out_type = infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] + out_shapes = [get_const_tuple(out_type.shape)] channels = out_shapes[0][0] if not transpose else out_shapes[0][1] return channels def infer_shape(inputs, mod=None): """A method to get the output type of an intermediate node in the graph.""" - out_type = infer_type(inputs, mod=mod) - checked_type = out_type.checked_type + checked_type = infer_type(inputs, mod=mod) if hasattr(checked_type, "shape"): # Regular operator that outputs tensors return get_const_tuple(checked_type.shape) diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index e515843e5fe2..5853cf623bb5 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -29,7 +29,7 @@ from ... import nd as _nd from ..._ffi import base as _base from .common import ExprTable -from .common import infer_shape as _infer_shape +from .common import infer_shape __all__ = ["from_coreml"] @@ -67,7 +67,7 @@ def _ConvolutionLayerParams(op, inexpr, etab): dilation = list(op.dilationFactor) if not dilation: dilation = [1, 1] - N, C, H, W = _infer_shape(inexpr) + N, C, H, W = infer_shape(inexpr) params = { "channels": op.outputChannels, "kernel_size": list(op.kernelSize), diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 59b4e99de999..fb633aadb58c 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -32,8 +32,7 @@ from ... import nd as _nd from .common import StrAttrsDict -from .common import infer_type as _infer_type -from .common import infer_shape as _infer_shape +from .common import infer_type, infer_shape from .common import infer_value as _infer_value from .common import get_name as _get_name from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce @@ -72,7 +71,7 @@ def _mx_fully_connected(inputs, attrs): use_flatten = attrs.get_bool("flatten", True) if has_flatten and use_flatten: inputs[0] = _op.nn.batch_flatten(inputs[0]) - data_shape = _infer_type(inputs[0]).checked_type.shape + data_shape = infer_type(inputs[0]).shape if len(data_shape) > 2: inputs[0] = _op.reverse_reshape(inputs[0], [-1, 0]) res = _op.nn.dense(inputs[0], inputs[1], units=units) @@ -119,8 +118,8 @@ def _stable_softrelu(x): def _mx_compare(new_op, wrapper): def impl(inputs, attrs): - expr = _infer_type(inputs[0]) - dtype = expr.checked_type.dtype + expr = infer_type(inputs[0]) + dtype = expr.dtype return wrapper(new_op)(inputs, attrs).astype(dtype) return impl @@ -137,7 +136,7 @@ def _mx_swap_axis(inputs, attrs): assert len(inputs) == 1 dim1 = attrs.get_int("dim1") dim2 = attrs.get_int("dim2") - shape = _infer_type(inputs[0]).checked_type.shape + shape = infer_type(inputs[0]).shape axes = list(range(len(shape))) axes[dim1] = dim2 axes[dim2] = dim1 @@ -418,7 +417,7 @@ def _pool3d(new_op, is_avg): return new_op(inputs[0], **new_attrs) # 3D pooling - if len(_infer_shape(inputs[0])) == 5: + if len(infer_shape(inputs[0])) == 5: if pool_type == "max": if global_pool: return _op.nn.global_max_pool3d(inputs[0]) @@ -513,7 +512,7 @@ def _mx_slice(inputs, attrs): begin = list(attrs.get_int_tuple("begin", None)) end = list(attrs.get_int_tuple("end", None)) stride = attrs.get_int_tuple("step", None) - input_shape = _infer_type(inputs[0]).checked_type.shape + input_shape = infer_type(inputs[0]).shape if begin is None: raise tvm.error.OpAttributeRequired('Attribute "begin" not found in operator Slice.') if end is None: @@ -538,8 
+537,8 @@ def _mx_slice_like(inputs, attrs): def _mx_slice_axis(inputs, attrs): assert len(inputs) == 1 - expr = _infer_type(inputs[0]) - shape = expr.checked_type.shape + expr = infer_type(inputs[0]) + shape = expr.shape axis = attrs.get_int("axis") ax_beg = attrs.get_int("begin") ax_end = attrs.get_str("end") @@ -582,8 +581,8 @@ def _mx_crop_like(inputs, attrs): if offset == (0, 0): new_attrs["axes"] = (2, 3) return _op.slice_like(*inputs, **new_attrs) - expr = _infer_type(inputs[1]) - like_shape = expr.checked_type.shape + expr = infer_type(inputs[1]) + like_shape = expr.shape new_attrs["begin"] = [0, 0, offset[0], offset[1]] new_attrs["end"] = [ like_shape[0], @@ -786,8 +785,8 @@ def _mx_multibox_detection(inputs, attrs): def _mx_dot(inputs, attrs): assert len(inputs) == 2 a, b = inputs - rank_a = len(_infer_type(a).checked_type.shape) - rank_b = len(_infer_type(b).checked_type.shape) + rank_a = len(infer_type(a).shape) + rank_b = len(infer_type(b).shape) if rank_a != 2 or rank_b != 2: raise tvm.error.OpAttributeUnimplemented("Only 2-D arrays are supported.") transpose_a = attrs.get_bool("transpose_a", False) @@ -803,12 +802,12 @@ def _mx_dot(inputs, attrs): def _mx_batch_dot(inputs, attrs): assert len(inputs) == 2 a, b = inputs - a_shape = _infer_type(a).checked_type.shape + a_shape = infer_type(a).shape batch_shapes = None if len(a_shape) > 3: batch_shapes = a_shape[:-2] a = _op.reverse_reshape(a, newshape=(-1, 0, 0)) - b_shape = _infer_type(b).checked_type.shape + b_shape = infer_type(b).shape if len(b_shape) > 3: if batch_shapes is None: batch_shapes = b_shape[:-2] @@ -859,7 +858,7 @@ def _mx_contrib_arange_like(inputs, attrs): raise tvm.error.OpAttributeUnimplemented( 'Attribute "repeat" is not supported in operator arange_like.' ) - ty = _infer_type(inputs[0]).checked_type + ty = infer_type(inputs[0]) assert ty shape, dtype = get_const_tuple(ty.shape), ty.dtype axis = attrs.get_int("axis", None) @@ -916,7 +915,7 @@ def _mx_take(inputs, attrs): def _mx_gather_nd(inputs, attrs): assert len(inputs) == 2 - assert len(_infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions" + assert len(infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions" return _op.gather_nd(inputs[0], inputs[1]) @@ -956,8 +955,8 @@ def _mx_resize(inputs, attrs): scale_width = attrs.get_float("scale_width", None) height = attrs.get_int("height", 1) width = attrs.get_int("width", 1) - expr = _infer_type(inputs[0]) - shape = expr.checked_type.shape + expr = infer_type(inputs[0]) + shape = expr.shape if scale_height is not None: height = (scale_height * shape[2]).astype("int32") if scale_width is not None: @@ -968,7 +967,7 @@ def _mx_resize(inputs, attrs): def _mx_amp_multicast(inputs, attrs): cast_narrow = attrs.get_bool("cast_narrow", False) - dtypes = [_infer_type(x).checked_type.dtype for x in inputs] + dtypes = [infer_type(x).dtype for x in inputs] supported_dtypes = ["float16", "float32"] assert all( [x in supported_dtypes for x in dtypes] @@ -989,7 +988,7 @@ def _mx_grid_generator(inputs, attrs): target_shape = attrs.get_int_tuple("target_shape") return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape) if transform_type == "warp": - checked_type = _infer_type(inputs[0]).checked_type + checked_type = infer_type(inputs[0]) batch, _, height, width = get_const_tuple(checked_type.shape) dtype = checked_type.dtype identity_affine = relay.const(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=dtype)) @@ -1107,10 +1106,10 @@ def _mx_l2_normalize(inputs, 
attrs): if mode == "channel": new_attrs["axis"] = [1] elif mode == "instance": - ndim = len(_infer_type(inputs[0]).checked_type.shape) + ndim = len(infer_type(inputs[0]).shape) new_attrs["axis"] = list(range(1, ndim)) elif mode == "spatial": - ndim = len(_infer_type(inputs[0]).checked_type.shape) + ndim = len(infer_type(inputs[0]).shape) new_attrs["axis"] = list(range(2, ndim)) else: raise tvm.error.OpAttributeInvalid( @@ -1172,8 +1171,8 @@ def _mx_broadcast_axis(inputs, attrs): assert len(axis) == len(size) if len(axis) == 0: return inputs[0] - expr = _infer_type(inputs[0]) - src_shape = expr.checked_type.shape + expr = infer_type(inputs[0]) + src_shape = expr.shape tgt_shape = [] for i, dim in enumerate(src_shape): if i not in axis: @@ -1259,9 +1258,9 @@ def _mx_sequence_mask(inputs, attrs): def _mx_contrib_div_sqrt_dim(inputs, _): assert len(inputs) == 1 - ndim = len(_infer_type(inputs[0]).checked_type.shape) + ndim = len(infer_type(inputs[0]).shape) dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim - 1, dtype="int32")) - dtype = _infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype sqrt_dim = _op.sqrt(dim.astype(dtype)) out = inputs[0] / sqrt_dim return out @@ -1280,8 +1279,8 @@ def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activati return out, [out] def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): - expr = _infer_type(data) - dtype = expr.checked_type.dtype + expr = infer_type(data) + dtype = expr.dtype i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1) h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1) i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1) @@ -1324,8 +1323,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_data = inputs[0] concat_weight = inputs[1] init_states = inputs[2:] - expr = _infer_type(seq_data) - data_shape = expr.checked_type.shape + expr = infer_type(seq_data) + data_shape = expr.shape seq_len = int(data_shape[0]) assert len(concat_weight) == num_layers * 4 * direct @@ -1531,7 +1530,7 @@ def _mx_cond(inputs, attrs, subgraphs): input_args = [] for i, arg in enumerate(inputs): - var = _expr.var("arg%s" % i, _infer_type(arg).checked_type) + var = _expr.var("arg%s" % i, infer_type(arg)) input_args.append(var) cond_args = [input_args[i] for i in cond_input_locs] then_args = [input_args[i] for i in then_input_locs] @@ -1541,7 +1540,7 @@ def _mx_cond(inputs, attrs, subgraphs): cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args] cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info) cond = _expr.Call(cond_func, cond_args).astype("bool") - cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape) + cond_shape = get_const_tuple(infer_type(cond).shape) if len(cond_shape) > 0: assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar" cond = _op.take(cond, _expr.const(1, "int")) @@ -1591,7 +1590,7 @@ def _qnn_contrib_concat(inputs, attrs): return concat, output_min, output_max else: # Get all dtypes. Find input and output scales, call concatenate. - dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs] + dtypes = [infer_type(x).dtype for x in input_exprs] assert all( [x == "uint8" for x in dtypes] ), "Current support is limited to uint8 inputs only." 
@@ -1648,8 +1647,8 @@ def _qnn_contrib_quantized_fifo_buffer(inputs, attrs, params): buffer = inputs[1] min_calib_range = inputs[2] max_calib_range = inputs[3] - data_dtype = _infer_type(data).checked_type.dtype - buffer_shape = _infer_shape(buffer) + data_dtype = infer_type(data).dtype + buffer_shape = infer_shape(buffer) buffer_name = _get_name(buffer) params[buffer_name] = _nd.array(np.zeros(buffer_shape).astype(data_dtype)) new_buffer = relay.var(buffer_name, relay.TensorType(buffer_shape, data_dtype)) @@ -1687,7 +1686,7 @@ def _get_data_scale_and_zp(_data, _inputs, _data_min_idx, _data_max_idx): data_min = _inputs[_data_min_idx] data_max = _inputs[_data_max_idx] assert data_min <= data_max - data_dtype = _infer_type(_data).checked_type.dtype + data_dtype = infer_type(_data).dtype assert data_dtype in {"int8", "uint8"} if data_min < 0.0: assert ( @@ -1815,7 +1814,7 @@ def _get_sum(_res, _output_scale, out_dtype): data_sum_min = inputs[-2] data_sum_max = inputs[-1] - data_sum_dtype = _infer_type(data_sum).checked_type.dtype + data_sum_dtype = infer_type(data_sum).dtype data_sum_scale = ( get_mkldnn_uint8_scale(data_sum_min, data_sum_max) if data_sum_dtype == "uint8" @@ -2006,7 +2005,7 @@ def _qnn_dequantize(inputs, attrs): data = inputs[0] input_min = inputs[1] input_max = inputs[2] - in_dtype = _infer_type(data).checked_type.dtype + in_dtype = infer_type(data).dtype result = dequantize_mxnet_min_max(data, input_min, input_max, in_dtype) return result @@ -2026,7 +2025,7 @@ def _qnn_pooling(inputs, attrs): input_min = inputs[1] input_max = inputs[2] data = inputs[0] - data_dtype = _infer_type(data).checked_type.dtype + data_dtype = infer_type(data).dtype pool_type = attrs.get_str("pool_type") if data_dtype in ("int8", "uint8") and pool_type != "max": data = _op.cast(data, "int32") @@ -2043,7 +2042,7 @@ def _qnn_batch_norm(inputs, attrs): # Dequantize the data. 
data_min_idx, data_max_idx = (-2, -1) data_min, data_max = inputs[data_min_idx], inputs[data_max_idx] - data_dtype = _infer_type(data).checked_type.dtype + data_dtype = infer_type(data).dtype data_scale = ( get_mkldnn_uint8_scale(data_min, data_max) if data_dtype == "uint8" @@ -2083,7 +2082,7 @@ def _get_input_scale_zp(_data_dtype, _inputs, _has_bias): return _data_scale, _data_zp def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias): - kernel_dtype = _infer_type(_kernel).checked_type.dtype + kernel_dtype = infer_type(_kernel).dtype if kernel_dtype != "int8": raise tvm.error.OpNotImplemented( @@ -2109,7 +2108,7 @@ def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias): return _kernel_scale, _kernel_zp def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale): - kernel_dtype = _infer_type(_kernel).checked_type.dtype + kernel_dtype = infer_type(_kernel).dtype if kernel_dtype != "float32": raise tvm.error.OpNotImplemented( "Channel wise quantized expects weights in float32 data type" @@ -2172,14 +2171,14 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): ############################## if is_flatten: data = _op.nn.batch_flatten(data) - data_shape = _infer_type(data).checked_type.shape + data_shape = infer_type(data).shape if len(data_shape) > 2: data = _op.reverse_reshape(data, [-1, 0]) ############################### # Get data scale and zero point ############################### - data_dtype = _infer_type(data).checked_type.dtype + data_dtype = infer_type(data).dtype data_scale, data_zp = _get_input_scale_zp(data_dtype, inputs, has_bias) ################################# @@ -2296,7 +2295,7 @@ def _mx_broadcast_like(inputs, attrs): def _mx_logical_not(inputs, input_types): data = inputs[0] - dtype = _infer_type(data).checked_type.dtype + dtype = infer_type(data).dtype data = _op.cast(data, "bool") if dtype != "bool" else data return _op.cast(_op.logical_not(data), dtype) @@ -2304,8 +2303,8 @@ def _mx_logical_not(inputs, input_types): def _mx_broadcast_logical(logical_op): def impl(inputs, input_types): - lhs_type = _infer_type(inputs[0]).checked_type.dtype - rhs_type = _infer_type(inputs[1]).checked_type.dtype + lhs_type = infer_type(inputs[0]).dtype + rhs_type = infer_type(inputs[1]).dtype lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0] rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1] @@ -2360,7 +2359,7 @@ def _mx_npx_reshape(inputs, attrs): shape = attrs.get_int_tuple("newshape") reverse = attrs.get_bool("reverse", False) shape_list = list(shape) - old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape) + old_shape = get_const_tuple(infer_type(inputs[0]).shape) new_shape = [] if reverse: old_shape = old_shape[::-1] @@ -2447,9 +2446,9 @@ def _mx_split_v2(inputs, attrs): def _mx_npi_where_rscalar(inputs, attrs): cond, dat = inputs scalar = attrs.get_float("scalar") - cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape) - dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape) - dtype = _infer_type(dat).checked_type.dtype + cond_shape = get_const_tuple(infer_type(cond).shape) + dat_shape = get_const_tuple(infer_type(dat).shape) + dtype = infer_type(dat).dtype # Check for broadcasting out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape if out_shape != cond_shape: diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index b2537af4b632..a97f408424ab 100644 --- 
a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -22,8 +22,7 @@ from .. import expr as _expr from .. import op as _op from .common import get_relay_op -from .common import infer_type as _infer_type -from .common import infer_shape as _infer_shape +from .common import infer_type, infer_shape def _warn_not_used(attr, op="nnvm"): @@ -74,9 +73,9 @@ def _impl(inputs, attrs, _dtype="float32"): data = inputs[0] length = inputs[1] - data_shape = _infer_shape(data) - data_dtype = _infer_type(data).checked_type.dtype - length_shape = _infer_shape(length) + data_shape = infer_shape(data) + data_dtype = infer_type(data).dtype + length_shape = infer_shape(length) if axis < 0: axis = len(data_shape) + axis @@ -188,7 +187,7 @@ def _impl(inputs, attrs, odtype=None): assert len(inputs) == 1 scalar = attrs.get_float("scalar") if odtype is None: - odtype = _infer_type(inputs[0]).checked_type.dtype + odtype = infer_type(inputs[0]).dtype scalar = _expr.const(scalar, dtype=odtype) return new_op(inputs[0], scalar) @@ -200,7 +199,7 @@ def _impl(inputs, attrs, odtype=None): assert len(inputs) == 1 scalar = attrs.get_float("scalar") if odtype is None: - odtype = _infer_type(inputs[0]).checked_type.dtype + odtype = infer_type(inputs[0]).dtype scalar = _expr.const(scalar, dtype=odtype) return new_op(scalar, inputs[0]) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0f78c32ef59f..126064f560c2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -299,7 +299,7 @@ class Pool(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): data = inputs[0] input_shape = infer_shape(data) - input_dtype = infer_type(data).checked_type.dtype + input_dtype = infer_type(data).dtype ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") @@ -471,7 +471,7 @@ def _impl_v1(cls, inputs, attr, params): ndim = len(input_shape) kernel_type = infer_type(inputs[1]) - kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + kernel_shapes = [get_const_tuple(kernel_type.shape)] if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -544,7 +544,7 @@ class ConvTranspose(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): # get number of channels out_type = infer_type(inputs[1]) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] + out_shapes = [get_const_tuple(out_type.shape)] channels = out_shapes[0][1] attr["channels"] = channels groups = attr.get("group", 1) @@ -600,7 +600,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v11(cls, inputs, attr, params): # get number of channels out_type = infer_type(inputs[1]) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] + out_shapes = [get_const_tuple(out_type.shape)] channels = out_shapes[0][1] attr["channels"] = channels groups = attr.get("group", 1) @@ -722,7 +722,7 @@ def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 3 or len(inputs) == 2, "Gemm op take 2 or 3 inputs, {} given".format( len(inputs) ) - dtype = infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype # Y = alpha * A * B + beta * C alpha = float(attr.get("alpha", 1.0)) beta = float(attr.get("beta", 1.0)) @@ -763,7 +763,7 @@ def flatten_to_nd(x, x_shape, nd=3): return x newshape = _op.concatenate( [ - _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), + _expr.const([-1], dtype=infer_type(x_shape).dtype), _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), ], 
0, @@ -773,7 +773,7 @@ def flatten_to_nd(x, x_shape, nd=3): b_type = infer_type(inputs[1]) # Convert to dense if the second matrix is 2d and non-dynamic - if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + if b_rank == 2 and not _ty.is_dynamic(b_type): a = flatten_to_nd(inputs[0], a_shape, 2) b = _op.transpose(inputs[1]) output = _op.nn.dense(a, b) @@ -857,7 +857,7 @@ class MaxUnpool(OnnxOpConverter): def _impl_v11(cls, inputs, attr, params): # Unpack inputs and attributes data = inputs[0] - data_type = infer_type(data).checked_type.dtype + data_type = infer_type(data).dtype indices = inputs[1] output_shape = inputs[2] kernel_shape = attr.get("kernel_shape") @@ -911,7 +911,7 @@ class LpPool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - dtype = infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype data = inputs[0] input_shape = infer_shape(data) ndim = len(input_shape) @@ -1063,7 +1063,7 @@ class Reciprocal(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - dtype = infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype return _expr.const(1.0, dtype=dtype) / inputs[0] @@ -1171,7 +1171,7 @@ class Shrink(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): x = inputs[0] - dtype = infer_type(x).checked_type.dtype + dtype = infer_type(x).dtype lambd = _op.const(attr.get("lambd", 0.5), dtype=dtype) bias = _op.const(attr.get("bias", 0.0), dtype=dtype) @@ -1323,7 +1323,7 @@ def _impl_v9(cls, inputs, attr, params): def shape_of(x, dtype="int64"): - ttype = infer_type(x).checked_type + ttype = infer_type(x) if not _ty.is_dynamic(ttype): shape = list(ttype.shape) return _expr.const(shape, dtype) @@ -1491,9 +1491,9 @@ def has_static_axes(): # Update the starts and ends according to axes if required. 
if axes is not None: - data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) + data_shape = shape_of(inputs[0], dtype=infer_type(ends).dtype) starts = _op.scatter( - _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype), + _op.const([0] * data_rank, dtype=infer_type(starts).dtype), axes, starts, axis=0, @@ -1501,14 +1501,14 @@ def has_static_axes(): ends = _op.scatter(data_shape, axes, ends, axis=0) if steps is not None: steps = _op.scatter( - _op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype), + _op.const([1] * data_rank, dtype=infer_type(steps).dtype), axes, steps, axis=0, ) if steps is None: - steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype) + steps = _op.const([1] * data_rank, dtype=infer_type(starts).dtype) return _op.strided_slice( inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps) @@ -1517,7 +1517,7 @@ def has_static_axes(): def normalize_gather_indices(data, indices, axis): """Make sure gather indicies aren't negative""" - ind_dtype = infer_type(indices).checked_type.dtype + ind_dtype = infer_type(indices).dtype # Normalize the indices to a positive range s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis)) cond = fold_constant(indices < _op.const(0, ind_dtype)) @@ -1603,7 +1603,7 @@ class EyeLike(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): - in_checked_type = infer_type(inputs[0]).checked_type + in_checked_type = infer_type(inputs[0]) in_dtype = in_checked_type.dtype in_shape = list(get_const_tuple(in_checked_type.shape)) dtype = attr.get("dtype", None) @@ -1912,7 +1912,7 @@ def _impl_v1(cls, inputs, attr, params): ndim = len(infer_shape(inputs[0])) if axis < 0: axis += ndim - dtype = infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype if axis == 0: pre = _op.const([1], "int64") @@ -1948,8 +1948,8 @@ def _impl_v9(cls, inputs, attr, params): # Split onnx on off values into two separate expressions. off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) # Extract the datatype of the output from on_value. 
- dtype = infer_type(on_value).checked_type.dtype - ind_dtype = infer_type(indices).checked_type.dtype + dtype = infer_type(on_value).dtype + ind_dtype = infer_type(indices).dtype # Normalize the indices to a positive range indices = _op.where( indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices @@ -2088,7 +2088,7 @@ class Expand(OnnxOpConverter): @classmethod def _impl_v8(cls, inputs, attr, params): - dtype = infer_type(inputs[1]).checked_type.dtype + dtype = infer_type(inputs[1]).dtype in_shape = shape_of(inputs[0], dtype=dtype) shape = inputs[1] @@ -2235,7 +2235,7 @@ def _impl_v7(cls, inputs, attr, params): Pp = inputs[7] num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype + W_dtype = infer_type(Wp).dtype if num_directions not in [1, 2]: raise ValueError("num_directions must be either 1 or 2!") @@ -2413,7 +2413,7 @@ def _impl_v7(cls, inputs, attr, params): linear_before_reset = attr.get("linear_before_reset", 0) num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype + W_dtype = infer_type(Wp).dtype if num_directions not in [1, 2]: raise NotImplementedError( @@ -2515,7 +2515,7 @@ def _impl_v10(cls, inputs, attr, params): ) scale = inputs[1] - size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + size = _op.cast(shape_of(inputs[0]), infer_type(scale).dtype) * scale ndims = len(infer_shape(inputs[0])) out = None if ndims == 3: @@ -2560,7 +2560,7 @@ def _impl_v11(cls, inputs, attr, params): size = inputs[3] else: assert len(scale_shape) != 0, "One of scale or size should be passed." - size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + size = _op.cast(shape_of(inputs[0]), infer_type(scale).dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) out = None if ndims == 3: @@ -2656,9 +2656,7 @@ def _impl_v1(cls, inputs, attr, params): if len(inputs) != 3: raise ValueError("Expect 3 input only") - return _op.arange( - inputs[0], inputs[1], inputs[2], dtype=infer_type(inputs[0]).checked_type.dtype - ) + return _op.arange(inputs[0], inputs[1], inputs[2], dtype=infer_type(inputs[0]).dtype) class IsInf(OnnxOpConverter): @@ -2668,7 +2666,7 @@ class IsInf(OnnxOpConverter): def _impl_v10(cls, inputs, attr, params): detect_negative = attr.get("detect_negative", 1) detect_positive = attr.get("detect_positive", 1) - dtype = infer_type(inputs[0]).checked_type.dtype + dtype = infer_type(inputs[0]).dtype isinf = _op.isinf(inputs[0]) if not detect_negative: isinf = isinf * (inputs[0] > _op.const(0, dtype)) @@ -2712,7 +2710,7 @@ def _impl_v1(cls, inputs, attr, params): spatial_scale = attr.get("spatial_scale", 1.0) batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1) - batch_indices = _op.cast(batch_indices, infer_type(rois).checked_type.dtype) + batch_indices = _op.cast(batch_indices, infer_type(rois).dtype) rois = _op.concatenate([batch_indices, rois], 1) return _vision.roi_align( @@ -2762,7 +2760,7 @@ class Softplus(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): data = inputs[0] - data_dtype = infer_type(data).checked_type.dtype + data_dtype = infer_type(data).dtype data = _op.exp(data) + _expr.const(1, dtype=data_dtype) return _op.log(data) @@ -2779,7 +2777,7 @@ def _impl_v11(cls, inputs, attr, params): # Create a copy of the body function to prevent the original # from being modified. 
body = copy.copy(attr["body"]) - iter_dtype = infer_type(max_loop_count).checked_type.dtype + iter_dtype = infer_type(max_loop_count).dtype # Determine what condition mode we're in. assert cond is not None or max_loop_count is not None @@ -2815,10 +2813,6 @@ def cond_fn(*loop_inputs): # Create a list of variables for each value updated in the loop. def get_var(name, val, scan=False): checked_type = infer_type(val) - if hasattr(checked_type, "type_annotation"): - checked_type = checked_type.type_annotation - if hasattr(checked_type, "checked_type"): - checked_type = checked_type.checked_type shape = get_const_tuple(checked_type.shape) actual_shape = [] for dim in shape: @@ -2856,8 +2850,8 @@ def get_var(name, val, scan=False): for i in range(num_scan_outputs): name, _, _, _ = get_info(body.output[i + 1 + num_deps]) output_node = infer_type(loop_outputs[i + 1 + num_deps]) - shape = get_const_tuple(output_node.checked_type.shape) - dtype = output_node.checked_type.dtype + shape = get_const_tuple(output_node.shape) + dtype = output_node.dtype scan_output_vars.append( _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) ) @@ -3008,7 +3002,7 @@ def _impl_v10(cls, inputs, attr, params): iou_threshold = inputs[3] score_threshold = inputs[4] - boxes_dtype = infer_type(boxes).checked_type.dtype + boxes_dtype = infer_type(boxes).dtype if attr.get("center_point_box", 0) != 0: xc, yc, w, h = _op.split(boxes, 4, axis=2) @@ -3112,13 +3106,13 @@ class QuantizeLinear(OnnxOpConverter): @classmethod def _impl_v10(cls, inputs, attr, params): data, scale, zp = inputs - out_dtype = infer_type(zp).checked_type.dtype + out_dtype = infer_type(zp).dtype return _qnn.op.quantize(data, scale, _op.cast(zp, "int32"), 0, out_dtype) @classmethod def _impl_v13(cls, inputs, attr, params): data, scale, zp = inputs - out_dtype = infer_type(zp).checked_type.dtype + out_dtype = infer_type(zp).dtype axis = attr.get("axis", 1) if len(infer_shape(data)) < 2: axis = 0 @@ -3147,7 +3141,7 @@ class DynamicQuantizeLinear(OnnxOpConverter): def _impl_v11(cls, inputs, attr, params): """This op is deprecated an only supports uint8""" data = inputs[0] - data_dtype = infer_type(data).checked_type.dtype + data_dtype = infer_type(data).dtype zero = _op.const(0, dtype=data_dtype) maximum = _op.maximum(zero, _op.max(data)) minimum = _op.minimum(zero, _op.min(data)) @@ -3189,7 +3183,7 @@ def get_scalar(x, dtype="float32"): ndim = len(input_shape) kernel_type = infer_type(weight) - kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + kernel_shapes = [get_const_tuple(kernel_type.shape)] if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -3245,7 +3239,7 @@ def get_scalar(x, dtype="float32"): if use_bias: out = _op.nn.bias_add(out, inputs[8]) - out_dtype = infer_type(inputs[7]).checked_type.dtype + out_dtype = infer_type(inputs[7]).dtype requantize_scale = _op.multiply(x_scale, w_scale) # requantize requires y_scale to be constant, @@ -3289,7 +3283,7 @@ def get_scalar(x, dtype="float32"): c_scale = get_scalar(inputs[6]) c_zero_point = get_scalar(inputs[7], "int32") - dtype = infer_type(a).checked_type.dtype + dtype = infer_type(a).dtype ## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32 ## and then requantize afer @@ -3327,7 +3321,7 @@ def get_scalar(x, dtype="float32"): y_scale = fold_constant(get_scalar(inputs[6])) y_zero_point = get_scalar(inputs[7], "int32") - dtype = infer_type(a).checked_type.dtype + dtype = infer_type(a).dtype ## Onnxruntime doesn't actually 
do this op in integer, they dequantize to fp32 ## and then requantize afer @@ -3353,11 +3347,11 @@ def _impl_v10(cls, inputs, attr, params): weight_zp = _expr.const(0, "int32") input_type = infer_type(data) - input_shape = get_const_tuple(input_type.checked_type.shape) + input_shape = get_const_tuple(input_type.shape) ndim = len(input_shape) kernel_type = infer_type(weight) - kernel_shape = get_const_tuple(kernel_type.checked_type.shape) + kernel_shape = get_const_tuple(kernel_type.shape) if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shape[2:] @@ -3992,7 +3986,7 @@ class Celu(OnnxOpConverter): @classmethod def _impl_v12(cls, inputs, attr, params): x = inputs[0] - dtype = infer_type(x).checked_type.dtype + dtype = infer_type(x).dtype alpha = _op.const(attr.get("alpha", 1.0), dtype) zero = _op.const(0, dtype) one = _op.const(1, dtype) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7c10889ce17e..0be229e26a6b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -26,12 +26,10 @@ import numpy as np import tvm -from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple from .. import analysis as _analysis from .. import expr as _expr -from .. import function as _function from .. import op as _op from .. import qnn, transform from ..expr_functor import ExprMutator @@ -41,7 +39,7 @@ from . import qnn_torch from .common import AttrCvt, get_relay_op, unbind, lstm_cell from .common import infer_value as _infer_value -from .common import infer_shape as _infer_shape +from .common import infer_type, infer_shape from .common import infer_value_simulated as _infer_value_simulated from .common import try_infer_value from .pytorch_utils import is_version_greater_than @@ -148,29 +146,13 @@ def infer_type(self, node, mod=None): if node in self.types: return self.types[node] - if isinstance(node, tvm.relay.Var): - return node.type_annotation - - tf = _TypeFinder(types=self.types) - new_node = tf.visit(node) - fn = _function.Function(list(tf.vars.values()), new_node) - new_mod = IRModule({"main": fn}) - if mod is not None: - new_mod.update(mod) - new_mod = transform.RemoveUnusedFunctions()(new_mod) - new_mod = transform.InferType()(new_mod) - entry = new_mod["main"] - ty = entry.body.checked_type - self.types[node] = ty - return self.types[node] - def infer_type_with_prelude(self, val): - body = self.infer_type(val, self.prelude.mod) - return body + self.types[node] = infer_type(node, mod) + return self.types[node] # list ADT utilities def convert_to_list_adt(self, py_lst): - elem_tys = [self.infer_type_with_prelude(elem) for elem in py_lst] + elem_tys = [self.infer_type(elem, self.prelude.mod) for elem in py_lst] msg = "List elements should have identical types" assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg @@ -193,7 +175,7 @@ def convert_to_tensor_array(self, adt_lst): if self.prelude.length(adt_lst) == 0: return nil() - checked_type = self.infer_type_with_prelude(self.prelude.hd(adt_lst)) + checked_type = self.infer_type(self.prelude.hd(adt_lst), self.prelude.mod) shape = checked_type.shape tensor_array = self.map_tensor_array_constructor(adt_lst, shape) return tensor_array, tuple(shape) @@ -213,11 +195,11 @@ def infer_shape_with_prelude(self, inputs): def record_output_type(self, output): if isinstance(output, tuple): cleaned_output = [o for o in output if o is not None] - types = self.infer_type_with_prelude(_expr.Tuple(cleaned_output)) + types = 
self.infer_type(_expr.Tuple(cleaned_output), self.prelude.mod) for o, t in zip(cleaned_output, types.fields): self.types[o] = t elif isinstance(output, _expr.Expr): - self.infer_type_with_prelude(output) + self.infer_type(output, self.prelude.mod) # it can also happen that the type is int or so def pytorch_promote_types(self, inputs, dtypes): @@ -247,7 +229,7 @@ def is_quantized_tensor(self, data): # If a quantized Torch module is saved and loaded back, dtype will be dropped # Since dtypes from Torch tensors are not reliable in such cases, we use # Relay's type inference result to decide if an input tensor is quantized - ty = self.infer_type_with_prelude(data) + ty = self.infer_type(data, self.prelude.mod) return ty.dtype == "uint8" # Operator implementations @@ -1463,7 +1445,7 @@ def linear(self, inputs, input_types): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: return _op.nn.bias_add(mm_out, bias) - mm_dtype = self.infer_type_with_prelude(mm_out).dtype + mm_dtype = self.infer_type(mm_out, self.prelude.mod).dtype return self.add([mm_out, bias], [mm_dtype, input_types[2]]) return mm_out @@ -1951,7 +1933,7 @@ def stack(self, inputs, input_types): else: # List ADT case assert isinstance(inputs[0], _expr.Expr) - ty = self.infer_type_with_prelude(inputs[0]) + ty = self.infer_type(inputs[0], self.prelude.mod) list_ty = self.prelude.mod.get_global_type_var("List") msg = "The input list is expected to be List ADT" assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg @@ -2451,23 +2433,23 @@ def lstm(self, inputs, input_types): if has_biases: if weights_num == 5: has_proj = True - proj_size = _infer_shape(_weights[4])[0] + proj_size = infer_shape(_weights[4])[0] else: assert weights_num == 4, "The weights number in layer is expected equal to 4" else: if weights_num == 3: has_proj = True - proj_size = _infer_shape(_weights[2])[0] + proj_size = infer_shape(_weights[2])[0] else: assert weights_num == 2, "The weights number in layer is expected equal to 2" X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X # TODO (vvchernov): Which data type should be used? from input or weights? - # Instead of it _infer_type(X).checked_type.dtype can be used + # Instead of it self.infer_type(X).checked_type.dtype can be used X_dtype = input_types[0] - X_shape = _infer_shape(X) # (seq_num, batch, feature_size) + X_shape = infer_shape(X) # (seq_num, batch, feature_size) - hidden_size = _infer_shape(_weights[0])[0] / 4 + hidden_size = infer_shape(_weights[0])[0] / 4 batch_size = X_shape[1] # Initialize hidden states if not provided. @@ -2884,7 +2866,7 @@ def get_input(index): def get_var(name, val): if val: - checked_type = self.infer_type_with_prelude(val) + checked_type = self.infer_type(val, self.prelude.mod) if hasattr(checked_type, "shape"): shape = get_const_tuple(checked_type.shape) actual_shape = [] diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d35e0e1c203d..c44d17fed952 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -34,8 +34,7 @@ from ..ty import Any from ..expr_functor import ExprMutator, ExprVisitor from .common import get_relay_op -from .common import infer_type as _infer_type -from .common import infer_shape as _infer_shape +from .common import infer_type, infer_shape from .common import infer_value as _infer_value from .tensorflow_ops import _convert_map @@ -344,7 +343,7 @@ def _while_loop(self): # beginning with loop name. 
if lv not in self._lvar2expr[self._loop_name]: var_name = "{}_loop_var_{}".format(self._loop_name, i) - var_type = _infer_type(lv, self._mod).checked_type + var_type = infer_type(lv, self._mod) loop_var = tvm.relay.var(var_name, type_annotation=var_type) self._lvar2expr[self._loop_name][loop_var] = lv bind_map[lv] = loop_var @@ -951,7 +950,7 @@ def _partition_call_operator(self, inputs, attr): subgraph_shape_dict, input_expr_dict = {}, {} for f_arg, input in zip(func.signature.input_arg, inputs): input_expr_dict[f_arg.name] = input - subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod) + subgraph_shape_dict[f_arg.name] = infer_shape(input, main_graph_proto._mod) func_name = "func_{}".format(func.signature.name) try: @@ -1078,7 +1077,7 @@ def _licm_construct(self, loop_name, node_name): if node_name not in self._lname_map[loop_name]: var_name = "{}_loop_var".format(node_name) - var_type = _infer_type(actual_expr, self._mod).checked_type + var_type = infer_type(actual_expr, self._mod) loop_var = tvm.relay.var(var_name, type_annotation=var_type) try: extra_param = _infer_value(actual_expr, self._params, self._mod) @@ -1158,7 +1157,7 @@ def _backtrack_construct(self, node_name): if output_index > 0: name += ":" + str(output_index) converted = self._backtrack_construct(name) - shape = _infer_shape(converted, self._mod) + shape = infer_shape(converted, self._mod) if wnode_op.startswith("TensorArraySplit"): shape = (Any(),) + shape[1:] elif wnode_op.startswith("TensorArrayScatter"): diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 465f530624b9..013e7efa438d 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -35,7 +35,7 @@ from .. import analysis from .. import function as _function from ..loops import while_loop as _while_loop -from .common import infer_type as _infer_type +from .common import infer_type from .tensorflow_ops import _convert_map as _convert_map_common from .tensorflow_ops import _get_more_static_shape_rank @@ -53,11 +53,6 @@ } -def _infer_type_with_prelude(val, prelude): - body = _infer_type(val, prelude.mod) - return body.checked_type - - def set_span(sym, node_name): """set span of symbol""" @@ -639,10 +634,10 @@ def convert_vars(loop_inputs, input_signature): new_vars = [] for i, v in enumerate(loop_inputs): if isinstance(v, _expr.Constant): - vtype = _infer_type(v).checked_type.dtype + vtype = infer_type(v).dtype new_vars.append(_expr.var(input_signature[i].name, shape=(), dtype=vtype)) else: - vtype = _infer_type_with_prelude(v, prelude) + vtype = infer_type(v, prelude.mod) new_vars.append(_expr.var(input_signature[i].name, type_annotation=vtype)) return new_vars @@ -744,7 +739,7 @@ def _convert_function( input_types = {} for f_arg, input_ in zip(func.signature.input_arg, inputs): input_expr_dict[f_arg.name] = input_ - input_types[f_arg.name] = _infer_type_with_prelude(input_, prelude) + input_types[f_arg.name] = infer_type(input_, prelude.mod) func_name = "func_{}".format(func.signature.name) try: diff --git a/python/tvm/relay/frontend/tensorflow2_ops.py b/python/tvm/relay/frontend/tensorflow2_ops.py index 17cd112878a5..62ce865a0e17 100644 --- a/python/tvm/relay/frontend/tensorflow2_ops.py +++ b/python/tvm/relay/frontend/tensorflow2_ops.py @@ -22,15 +22,10 @@ from .. 
import op as _op from ..ty import Any from .common import infer_value as _infer_value -from .common import infer_type as _infer_type +from .common import infer_type from .tensorflow_ops import _get_more_static_shape_rank -def _infer_type_with_prelude(val, prelude): - body = _infer_type(val, prelude.mod) - return body.checked_type - - def _need_prelude_for_shape_inference(op): return "TensorList" in op or "TensorArray" in op @@ -60,7 +55,7 @@ def _impl(inputs, attr, params, prelude): dtype_str = attr.get("element_dtype").name input_ta = inputs[0] input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - input_t_shape = _infer_type_with_prelude(inputs[2], prelude).shape + input_t_shape = infer_type(inputs[2], prelude.mod).shape input_rank = len(input_t_shape) if input_ta_shape is None: @@ -161,7 +156,7 @@ def _impl(inputs, attr, params, prelude): def _tensorlist_from_tensor(): def _impl(inputs, attr, params, prelude): dtype_str = attr["element_dtype"].name - input_ta_shape = _infer_type_with_prelude(inputs[0], prelude).shape + input_ta_shape = infer_type(inputs[0], prelude.mod).shape if input_ta_shape is None: unstack_func = prelude.get_global_var("tensor_array_unstack", dtype_str) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index a8213d4b1c49..688162e46214 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -31,8 +31,7 @@ from .. import op as _op from ..ty import Any from .common import AttrCvt, get_relay_op -from .common import infer_type as _infer_type -from .common import infer_shape as _infer_shape +from .common import infer_type, infer_shape from .common import infer_channels as _infer_channels from .common import infer_value as _infer_value @@ -192,7 +191,7 @@ def _impl(inputs, attr, params, mod): attr["data_format"] = attr["data_format"].decode("utf-8") flip_layout = False - input_shape = _infer_shape(inputs[0], mod) + input_shape = infer_shape(inputs[0], mod) if attr["data_format"] == "NDHWC": attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2], attr["ksize"][3]) @@ -204,7 +203,7 @@ def _impl(inputs, attr, params, mod): msg = 'Value {} of attribute "data_format" of operator Pooling ' "is not valid." 
raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) if attr["data_format"] == "NDHWC": - input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)] + input_shape = [infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)] inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) attr["data_format"] = "NCDHW" flip_layout = True @@ -254,7 +253,7 @@ def _impl(inputs, attr, params, mod): attr["data_format"] = attr["data_format"].decode("utf-8") flip_layout = False - input_shape = _infer_shape(inputs[0], mod) + input_shape = infer_shape(inputs[0], mod) if attr["data_format"] == "NHWC": attr["kernel_shape"] = (attr["ksize"][1], attr["ksize"][2]) @@ -267,7 +266,7 @@ def _impl(inputs, attr, params, mod): raise tvm.error.OpAttributeInvalid(msg.format(attr["data_format"])) if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": - tmp_shape = _infer_shape(inputs[0], mod) + tmp_shape = infer_shape(inputs[0], mod) input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) attr["data_format"] = "NCHW" @@ -353,7 +352,7 @@ def _impl(inputs, attr, params, mod): inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] # NCHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1], mod) + weights_shape = infer_shape(inputs[1], mod) if attr["data_format"] == "NCHW": tmp_shape = weights_shape if opname in ["conv", "conv_transpose"]: @@ -364,7 +363,7 @@ def _impl(inputs, attr, params, mod): inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) weights_shape = tmp_shape - input_shape = _infer_shape(inputs_data, mod) + input_shape = infer_shape(inputs_data, mod) if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) @@ -495,8 +494,8 @@ def _impl(inputs, attr, params, mod): if "data_format" not in attr: attr["data_format"] = "NHWC" - input_shape = _infer_shape(inputs[0], mod) - weights_shape = _infer_shape(inputs[1], mod) + input_shape = infer_shape(inputs[0], mod) + weights_shape = infer_shape(inputs[1], mod) if attr["_target_layout"] == "NCHW" and attr["data_format"] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] @@ -578,14 +577,14 @@ def _impl(inputs, attr, params, mod): inputs_data = inputs[0] if opname != "conv_transpose" else inputs[2] # NCDHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1], mod) + weights_shape = infer_shape(inputs[1], mod) if attr["data_format"] == "NCDHW": tmp_shape = weights_shape tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) weights_shape = tmp_shape - input_shape = _infer_shape(inputs_data, mod) + input_shape = infer_shape(inputs_data, mod) if attr["_target_layout"] == "NCDHW" and attr["data_format"] == "NDHWC": input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] @@ -900,8 +899,8 @@ def _impl(inputs, attr, params, mod): raise tvm.error.OpAttributeUnImplemented( "pad_per_class for CombinedNonMaxSuppression is not supported" ) - boxes_shape = _infer_shape(inputs[0], mod) - scores_shape = _infer_shape(inputs[1], mod) + boxes_shape = infer_shape(inputs[0], mod) + scores_shape = infer_shape(inputs[1], mod) batch_size = boxes_shape[0] num_anchors = boxes_shape[1] q = boxes_shape[2] @@ -1153,8 +1152,8 @@ def _impl(inputs, attr, params, mod): input_x = inputs[0] input_y = inputs[1] - orig_shape_x = _infer_shape(input_x, 
mod) - orig_shape_y = _infer_shape(input_y, mod) + orig_shape_x = infer_shape(input_x, mod) + orig_shape_y = infer_shape(input_y, mod) ndim = len(orig_shape_x) ndim_y = len(orig_shape_y) @@ -1308,7 +1307,7 @@ def _impl(inputs, attr, params, mod): assert len(inputs) == 4, "There should be 4 input tensors" sparse_indices = inputs[0] sparse_values = inputs[1] - sparse_indices_num_cols = _infer_shape(sparse_indices, mod)[1] + sparse_indices_num_cols = infer_shape(sparse_indices, mod)[1] first_column = _op.split(sparse_indices, sparse_indices_num_cols, axis=1)[0] sorted_indices = _op.argsort(_op.squeeze(first_column)) sorted_sparse_indices = _op.take(sparse_indices, sorted_indices, axis=0) @@ -1560,9 +1559,9 @@ def _impl(inputs, attr, params, prelude): dtype_str = attr.get("T").name input_ta = inputs[0] input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - values_shape = _infer_shape(inputs[2], prelude.mod) + values_shape = infer_shape(inputs[2], prelude.mod) input_t_shape = values_shape[1:] - indices_shape = _infer_shape(inputs[1], prelude.mod) + indices_shape = infer_shape(inputs[1], prelude.mod) if input_shape is None: values_rank = len(values_shape) @@ -1598,7 +1597,7 @@ def _tensor_array_gather(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get("dtype").name input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) - indices_shape = _infer_shape(inputs[1], prelude.mod) + indices_shape = infer_shape(inputs[1], prelude.mod) if input_shape is None: gather_func = prelude.get_var("tensor_array_gather", dtype_str) @@ -1657,7 +1656,7 @@ def _impl(inputs, attr, params, prelude): dtype_str = attr.get("T").name input_ta = inputs[3] input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - input_t_shape = _infer_shape(inputs[2], prelude.mod) + input_t_shape = infer_shape(inputs[2], prelude.mod) input_rank = len(input_t_shape) if input_ta_shape is None: @@ -1722,8 +1721,8 @@ def _impl(inputs, attr, params, prelude): input_ta = inputs[0] input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) lengths = _op.cast(inputs[2], "int32") - lengths_shape = _infer_shape(lengths, prelude.mod) - value_shape = _infer_shape(inputs[1], prelude.mod) + lengths_shape = infer_shape(lengths, prelude.mod) + value_shape = infer_shape(inputs[1], prelude.mod) input_rank = len(value_shape) if input_ta_shape is None: @@ -1813,7 +1812,7 @@ def _impl(inputs, attr, params, mod): size = inputs[2] # Align begin and strides for dynamic shape. 
- data_dim = len(_infer_shape(inputs[0], mod)) + data_dim = len(infer_shape(inputs[0], mod)) strides = [1] * data_dim if not isinstance(begin, (_expr.Call, _expr.Var)): for _ in range(len(begin), data_dim): @@ -2030,7 +2029,7 @@ def _impl(inputs, attr, params, mod): def _shape(): def _impl(inputs, attr, params, mod): is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], mod) + input_shape = infer_shape(inputs[0], mod) for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True @@ -2148,7 +2147,7 @@ def _gather_nd(): """GatherNd""" def _impl(inputs, attr, params, mod): - indices_dims = len(_infer_shape(inputs[1], mod)) + indices_dims = len(infer_shape(inputs[1], mod)) indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1))) return AttrCvt(op_name="gather_nd", ignores=["Tindices", "Tparams", "Taxis", "_class"])( [inputs[0], indices], attr @@ -2173,8 +2172,8 @@ def _impl(inputs, attr, params, mod): ellipsis_mask = int(attr.get("ellipsis_mask", 0)) new_axis_mask = int(attr.get("new_axis_mask", 0)) shrink_axis_mask = int(attr.get("shrink_axis_mask", 0)) - in_type = _infer_type(inputs[0], mod) - data_shape = get_const_tuple(in_type.checked_type.shape) + in_type = infer_type(inputs[0], mod) + data_shape = get_const_tuple(in_type.shape) data_dim = len(data_shape) stride_dim = len(stride) if data_dim == 0 and isinstance(inputs[0], _expr.Constant): @@ -2194,8 +2193,8 @@ def _impl(inputs, attr, params, mod): if ed <= 0 < st: ed += data_shape[0] - in_shape = _infer_shape(inputs[0].args[0], mod) - dtype = in_type.checked_type.dtype + in_shape = infer_shape(inputs[0].args[0], mod) + dtype = in_type.dtype out_data = [] idx = bg while idx < ed: @@ -2284,7 +2283,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_shape(out, mod=mod) + out_shape = infer_shape(out, mod=mod) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -2404,7 +2403,7 @@ def _impl(inputs, attr, params, mod): def _rank(): def _impl(inputs, attr, params, mod): - input_shape = _infer_shape(inputs[0], mod) + input_shape = infer_shape(inputs[0], mod) name = attr["_node_name"] params[name] = tvm.nd.array(np.array([len(input_shape)]).astype("int32")) @@ -2561,7 +2560,7 @@ def _unpack(): def _impl(inputs, attr, params, mod): input_node = inputs[0] axis = attr["axis"] - input_shape = _infer_shape(input_node, mod) + input_shape = infer_shape(input_node, mod) axis_length = input_shape[axis] if axis_length < 0: raise TypeError("Unstack with unknown axis length") @@ -2801,8 +2800,8 @@ def _impl(inputs, attr, params, mod): in_weight = inputs[3] in_bias = inputs[7] forget_bias = attr.pop("forget_bias") - input_shape = _infer_shape(inputs[0], mod) - weight_shape = _infer_shape(inputs[3], mod) + input_shape = infer_shape(inputs[0], mod) + weight_shape = infer_shape(inputs[3], mod) batch_size, input_size = input_shape[0], input_shape[1] num_hidden_layers = weight_shape[1] diff --git a/tests/python/relay/test_const.py b/tests/python/relay/test_const.py index c815f6bd4fa4..7ba0463cc059 100644 --- a/tests/python/relay/test_const.py +++ b/tests/python/relay/test_const.py @@ -28,7 +28,7 @@ def test_const_dtype(): strides = _op.const(np_array, dtype="int64") # strides needs to be autoconverted to int64 on Windows - assert 
infer_type(strides).checked_type.dtype == np.dtype(np.int64) + assert infer_type(strides).dtype == np.dtype(np.int64) a = tvm.nd.array(np.random.randint(0, high=255, size=(2, 3), dtype="uint8")) a = _op.const(a, dtype="uint8")
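
For readers of this patch, a minimal usage sketch (not part of the diff) of the call-site pattern the refactor standardizes: `common.infer_type` now returns the inferred type itself rather than a typed expression, `relay.Var` inputs short-circuit to their `type_annotation`, and `infer_shape` builds on top of it unchanged. The snippet assumes a TVM build with this change applied; the expression names are illustrative only.

    from tvm import relay
    from tvm.relay.frontend.common import infer_type, infer_shape

    # Hypothetical intermediate expression, as a frontend converter would build it.
    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    y = relay.nn.relu(x)

    # With this patch, infer_type returns the checked type directly,
    # so call sites read .shape / .dtype without going through .checked_type.
    ty = infer_type(y)
    print(ty.shape, ty.dtype)

    # infer_shape still returns a tuple of concrete dimensions.
    print(infer_shape(y))

    # Vars return their type annotation without building an IRModule.
    print(infer_type(x))

Returning the checked type directly (and short-circuiting variables to their annotation) is what lets every frontend in this patch drop the repeated `.checked_type` access, and it removes the per-call module construction for inputs whose type is already known.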