From 061fbc95eccca144d0b4ac7065cbbda323f6e1b2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 16 Feb 2021 23:44:55 +0900 Subject: [PATCH 1/3] simplyfing --- python/tvm/relay/frontend/pytorch.py | 51 +++++++++++----------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 205b2aa779e6..58d58cff1552 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -385,23 +385,20 @@ def tensor_array_concat(lst, axis): def slice(self, inputs, input_types): axis_dtype = "int64" - index_size_limit = 2 ** 63 - 1 + index_size_limit = sys.maxsize data = inputs[0] dshape = self.infer_shape(data) ndim = len(dshape) - end = [] - for dim in dshape: - if isinstance(dim, tvm.tir.Any): - end = _op.shape_of(data) - break - end.append(int(dim)) - - begin = [0] * ndim dim = int(inputs[1]) - stride = int(inputs[4]) - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) + stride = inputs[4] + + target_begin, is_begin_const = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) + target_end, is_end_const = try_infer_value(inputs[3], lambda ret: np.asscalar(ret.astype(np.int))) # Process begin + begin = [0] * ndim + begin[dim] = target_begin + if not isinstance(begin[dim], int): tmp = [] for b in begin: @@ -414,27 +411,17 @@ def slice(self, inputs, input_types): if str(btype) != axis_dtype: begin = _op.cast(begin, axis_dtype) - if isinstance(inputs[3], str) and inputs[3].isdigit(): - target_end = int(inputs[3]) - else: - if isinstance(inputs[3], _expr.Expr): - target_end, _ = try_infer_value( - inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) - ) - else: - target_end = inputs[3] - - if isinstance(target_end, int) and target_end >= index_size_limit: - # Quick path for original data. - if ( - isinstance(begin, _expr.Constant) - and begin.data.asnumpy().tolist()[dim] == 0 - and stride == 1 - ): - return data - target_end = dshape[dim] - # Process end + if isinstance(target_end, int) and target_end >= index_size_limit: + target_end = dshape[dim] + + end = [] + for d in dshape: + if isinstance(d, tvm.tir.Any): + end = _op.shape_of(data) + break + end.append(int(d)) + if isinstance(target_end, int): if isinstance(end, list): end[dim] = target_end @@ -474,7 +461,7 @@ def slice(self, inputs, input_types): end = _op.cast(end, axis_dtype) strides = [1] * ndim - strides[dim] = int(inputs[4]) + strides[dim] = stride return _op.transform.strided_slice( data, begin=begin, end=end, strides=strides, slice_mode="end" From 58e029c720fc185556f888746a5d24d8e086655d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Feb 2021 00:09:25 +0900 Subject: [PATCH 2/3] improved fast path for slice --- python/tvm/relay/frontend/pytorch.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 58d58cff1552..931611274c20 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -392,8 +392,16 @@ def slice(self, inputs, input_types): dim = int(inputs[1]) stride = inputs[4] - target_begin, is_begin_const = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) - target_end, is_end_const = try_infer_value(inputs[3], lambda ret: np.asscalar(ret.astype(np.int))) + target_begin, is_begin_const = try_infer_value( + inputs[2], lambda ret: np.asscalar(ret.astype(np.int)) + ) + target_end, is_end_const = try_infer_value( + inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) + ) + + # A fast path when slicing is nop. + if target_begin == 0 and target_end >= index_size_limit and stride == 1: + return data # Process begin begin = [0] * ndim @@ -415,12 +423,10 @@ def slice(self, inputs, input_types): if isinstance(target_end, int) and target_end >= index_size_limit: target_end = dshape[dim] - end = [] - for d in dshape: - if isinstance(d, tvm.tir.Any): - end = _op.shape_of(data) - break - end.append(int(d)) + if any([isinstance(d, tvm.tir.Any) for d in dshape]): + end = _op.shape_of(data) + else: + end = dshape if isinstance(target_end, int): if isinstance(end, list): From 58d1c66a57a3e93b6d8889ef846d9bc8bf7d0b3a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Feb 2021 09:30:46 +0900 Subject: [PATCH 3/3] update rewrite pattern for maskrcnn --- python/tvm/relay/frontend/pytorch_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 248f5354cfbb..02b2484d4fb7 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -97,15 +97,11 @@ def batched_nms(boxes, scores, idxs, iou_threshold): add = is_op("add")(mx, one) mul = is_op("multiply")(cast, add) - # The following doesn't appear in the above Relay snippet. It is required for dynamic - # stride_slice handling shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) - # This corresponds to offsets[:, None], where offsets is the result of multiplication - dyn_strided_slice = dyn_strided_slice_pattern(mul, cast) # Add offsets to the boxes - expand_dims = is_op("expand_dims")(dyn_strided_slice) + expand_dims = is_op("expand_dims")(mul) add = is_op("add")(boxes, expand_dims) # The rest of patterns correspond to the PyTorch frontend conversion