Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,24 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {

/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<Integer> begin;
Array<Integer> end;
Array<Integer> strides;
Optional<Array<Integer>> begin;
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
std::string slice_mode;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
TVM_ATTR_FIELD(strides).set_default(Array<Integer>({})).describe("Stride values of the slice");
TVM_ATTR_FIELD(strides).describe(
"Stride values of the slice, a stride can be negative, which causes a reverse slice.");
TVM_ATTR_FIELD(slice_mode)
.set_default("end")
.describe(
"The slice mode [end, size]."
"end - The default slice mode, ending indices for the slice."
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
}
};

Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
.describe(
"Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5).describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(iou_threshold)
.set_default(0.5)
.describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(force_suppress)
.set_default(false)
.describe("Suppress all detections regardless of class_id.");
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
if self.operator is op.reshape:
if self.operator in (op.reshape, op.strided_slice):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,8 @@ def _convert_cropping(inexpr, keras_layer, _):
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(crop_type))
int32_max = np.iinfo(np.int32).max
return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \
end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r]))


def _convert_batchnorm(inexpr, keras_layer, etab):
Expand Down
25 changes: 17 additions & 8 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,16 +411,22 @@ def _mx_slice(inputs, attrs):
begin = list(attrs.get_int_tuple('begin', None))
end = list(attrs.get_int_tuple('end', None))
stride = attrs.get_int_tuple('step', None)
input_shape = _infer_type(inputs[0]).checked_type.shape
if begin is None:
raise tvm.error.OpAttributeRequired(
'Attribute "begin" not found in operator Slice.')
if end is None:
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
begin = tuple(x if x is not None else 0 for x in begin)
new_attrs = {'begin': begin, 'end': end}
begin = (x if x is not None else 0 for x in begin)
for i, ed in enumerate(end):
if ed is None:
end[i] = input_shape[i]
new_attrs = {'begin': _expr.const(list(begin), dtype="int32"),
'end': _expr.const(list(end), dtype="int32")}
if stride is not None:
new_attrs['strides'] = stride
stride = (x if x is not None else 1 for x in stride)
new_attrs['strides'] = _expr.const(list(stride), dtype="int32")
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down Expand Up @@ -460,7 +466,9 @@ def _mx_slice_axis(inputs, attrs):
else:
begin.append(ax_beg)
end.append(ax_end)
return _op.strided_slice(inputs[0], begin, end)
return _op.strided_slice(inputs[0],
_expr.const(begin, dtype="int32"),
_expr.const(end, dtype="int32"))


def _mx_crop_like(inputs, attrs):
Expand All @@ -480,9 +488,9 @@ def _mx_crop_like(inputs, attrs):
return _op.slice_like(*inputs, **new_attrs)
expr = _infer_type(inputs[1])
like_shape = expr.checked_type.shape
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]]
new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32")
new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]], dtype="int32")
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down Expand Up @@ -656,7 +664,7 @@ def _mx_multibox_detection(inputs, attrs):

ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1)
return _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **new_attrs1)


def _mx_batch_dot(inputs, attrs):
Expand Down Expand Up @@ -820,6 +828,7 @@ def _mx_box_nms(inputs, attrs):
id_index=id_index, score_index=score_index)
nms_out = _op.vision.non_max_suppression(ret[1],
ret[0],
ret[2],
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,11 +945,12 @@ def _impl_v1(cls, inputs, attr, params):
attr['ends'] = new_ends
except KeyError:
pass
begin = list(attr['starts'])
end = list(attr['ends'])

return AttrCvt('strided_slice',
transforms={'starts': 'begin',
'ends': 'end'},
ignores=['axes'])(inputs, attr)
return _op.strided_slice(inputs[0],
begin=_expr.const(begin, dtype="int32"),
end=_expr.const(end, dtype="int32"))

@classmethod
def _impl_v10(cls, inputs, attr, params):
Expand All @@ -965,7 +966,9 @@ def _impl_v10(cls, inputs, attr, params):
starts, ends, axes)
starts = new_starts
ends = new_ends
return _op.strided_slice(inputs[0], begin=starts, end=ends)
return _op.strided_slice(inputs[0],
begin=_expr.const(starts, dtype="int32"),
end=_expr.const(ends, dtype="int32"))


class Gather(OnnxOpConverter):
Expand Down
16 changes: 13 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def _impl(inputs, input_types):
end[dim] = inputs[3]

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data, begin, end, strides)
return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
slice_mode="size")
return _impl

def _split():
Expand Down Expand Up @@ -1233,7 +1237,10 @@ def _impl(inputs, input_types):
end[axis] = i + unif_size
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
chunks.append(chunk_out)

if dim % num_chunks:
Expand All @@ -1243,7 +1250,10 @@ def _impl(inputs, input_types):
end[axis] = dim
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
chunks.append(chunk_out)

return chunks
Expand Down
84 changes: 70 additions & 14 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,62 @@ def _impl(inputs, attr, params, mod):
return out
return _impl

def _nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
# TODO(yongwww) change nms in relay to support symbolic max_output_size
try:
max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy()
.astype("int64"))[0])
except Exception:
try:
max_output_size = _infer_value(inputs[2], params,
mod).asnumpy().astype("int64").tolist()[0]
except Exception:
max_output_size = -1
iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
# score_threshold was introduced from V3
score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims",
ignores=['T_threshold'],
extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
data = get_relay_op('concatenate')([scores, inputs[0]], -1)
data = get_relay_op('expand_dims')(data, 0, 1)

# reason why using get_valid_counts is for inference performance
ct, data, indices = get_relay_op('get_valid_counts')(data,
score_threshold=score_threshold,
id_index=-1,
score_index=0)
# TensorFlow NMS doesn't have parameter top_k
top_k = -1
# TF doesn't have class id for nms input
score_index = 0
nms_ret = get_relay_op('non_max_suppression')(data=data,
valid_count=ct,
indices=indices,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False)

# squeeze it, TF NMS is not batched
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

# slice to get the dynamic result
ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
end=size, slice_mode="size")
return ret
return _impl

def _decode_image():
def _impl(inputs, attr, params, mod):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
Expand Down Expand Up @@ -1119,25 +1175,20 @@ def _impl(inputs, attr, params, mod):
try:
begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
begin = _infer_value(inputs[1], params).asnumpy().tolist()[0]
# Handle symbolic begin
try:
begin = _infer_value(inputs[1], params).asnumpy().tolist()
except Exception:
begin = inputs[1]
try:
size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
# Handle symbolic size
try:
size = _infer_value(inputs[2], params).asnumpy().tolist()[0]
size = _infer_value(inputs[2], params).asnumpy().tolist()
except Exception:
size = inputs[2]
data_shape = _infer_shape(inputs[0], mod)
data_dim = len(data_shape)
end = size
if not isinstance(end, (_expr.Call, _expr.Var)):
for i in range(data_dim):
if size[i] == -1:
end[i] = data_shape[i]
else:
end[i] += begin[i]
return _op.strided_slice(inputs[0], begin=begin, end=end)
return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode="size")
return _impl


Expand Down Expand Up @@ -1466,8 +1517,11 @@ def _transform_mask(stride_dim, ellipsis_mask):
fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out, mod)
out = _op.strided_slice(inputs[0],
begin=begin,
end=end,
strides=stride)
out_shape = _infer_shape(out, mod=mod)
if not fshape_indices:
fshape_indices = range(len(out_shape))

Expand Down Expand Up @@ -2027,6 +2081,8 @@ def _impl(inputs, attr, params, mod):
'Mod' : _elemwise('mod'),
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NonMaxSuppressionV2' : _nms(),
'NonMaxSuppressionV3' : _nms(),
'NoOp' : _no_op(),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,7 +2439,7 @@ def convert_detection_postprocess(self, op):

ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
anchor_expr, **multibox_transform_loc_attrs)
ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
ret = _op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
# keep only the top 'max_detections' rows
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,10 @@ def conv2d_grad(orig, grad):
assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
end=[None, None, filter_h, filter_w])
backward_weight = strided_slice(backward_weight,
begin=const([0, 0, 0, 0], dtype="int64"),
end=const([out_channel, in_channel // attrs.groups,
filter_h, filter_w], dtype="int64"))

return [backward_data, backward_weight]

Expand Down
Loading