diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 205b2aa779e6..931611274c20 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -385,23 +385,28 @@ def tensor_array_concat(lst, axis): def slice(self, inputs, input_types): axis_dtype = "int64" - index_size_limit = 2 ** 63 - 1 + index_size_limit = sys.maxsize data = inputs[0] dshape = self.infer_shape(data) ndim = len(dshape) - end = [] - for dim in dshape: - if isinstance(dim, tvm.tir.Any): - end = _op.shape_of(data) - break - end.append(int(dim)) - - begin = [0] * ndim dim = int(inputs[1]) - stride = int(inputs[4]) - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) + stride = inputs[4] + + target_begin, is_begin_const = try_infer_value( + inputs[2], lambda ret: np.asscalar(ret.astype(np.int)) + ) + target_end, is_end_const = try_infer_value( + inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) + ) + + # A fast path when slicing is nop. + if target_begin == 0 and target_end >= index_size_limit and stride == 1: + return data # Process begin + begin = [0] * ndim + begin[dim] = target_begin + if not isinstance(begin[dim], int): tmp = [] for b in begin: @@ -414,27 +419,15 @@ def slice(self, inputs, input_types): if str(btype) != axis_dtype: begin = _op.cast(begin, axis_dtype) - if isinstance(inputs[3], str) and inputs[3].isdigit(): - target_end = int(inputs[3]) + # Process end + if isinstance(target_end, int) and target_end >= index_size_limit: + target_end = dshape[dim] + + if any([isinstance(d, tvm.tir.Any) for d in dshape]): + end = _op.shape_of(data) else: - if isinstance(inputs[3], _expr.Expr): - target_end, _ = try_infer_value( - inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) - ) - else: - target_end = inputs[3] - - if isinstance(target_end, int) and target_end >= index_size_limit: - # Quick path for original data. - if ( - isinstance(begin, _expr.Constant) - and begin.data.asnumpy().tolist()[dim] == 0 - and stride == 1 - ): - return data - target_end = dshape[dim] + end = dshape - # Process end if isinstance(target_end, int): if isinstance(end, list): end[dim] = target_end @@ -474,7 +467,7 @@ def slice(self, inputs, input_types): end = _op.cast(end, axis_dtype) strides = [1] * ndim - strides[dim] = int(inputs[4]) + strides[dim] = stride return _op.transform.strided_slice( data, begin=begin, end=end, strides=strides, slice_mode="end" diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 248f5354cfbb..02b2484d4fb7 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -97,15 +97,11 @@ def batched_nms(boxes, scores, idxs, iou_threshold): add = is_op("add")(mx, one) mul = is_op("multiply")(cast, add) - # The following doesn't appear in the above Relay snippet. It is required for dynamic - # stride_slice handling shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) - # This corresponds to offsets[:, None], where offsets is the result of multiplication - dyn_strided_slice = dyn_strided_slice_pattern(mul, cast) # Add offsets to the boxes - expand_dims = is_op("expand_dims")(dyn_strided_slice) + expand_dims = is_op("expand_dims")(mul) add = is_op("add")(boxes, expand_dims) # The rest of patterns correspond to the PyTorch frontend conversion