Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,40 @@ def _impl_v10(cls, inputs, attr, params):

@classmethod
def _impl_v11(cls, inputs, attr, params):
scale = inputs[2]
scale_shape = infer_shape(scale)
if len(inputs) == 4:
assert (
len(scale_shape) == 0 or scale_shape[0] == 0
), "One of scale or size should be passed, not both."
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the scales or sizes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above!

size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
return cls.v11_13_common(inputs, size, attr, params)

@classmethod
def _impl_v13(cls, inputs, attr, params):
scale = inputs[2]
size = inputs[3]
if size is not None:
assert scale is None, "One of scale or size should be passed, not both."
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the scales or sizes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both scale and size are singular (there is only one of them) and adding an 's' makes them plural. It's only one dog or cat, not dogs or cats 🐶 🐈 vs 🐶 🐶 🐶 🐈 🐈 🐈

else:
scale_type = infer_type(scale)
scale_shape = scale_type.checked_type.shape
scale_dtype = scale_type.checked_type.dtype
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), scale_dtype) * scale

return cls.v11_13_common(inputs, size, attr, params)

@classmethod
def v11_13_common(cls, inputs, size, attr, params):
"""
Resize v11 and Resize v13 are identical except in how
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resize v11 and resize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi emijiayw, Thanks for your feedback! Resize v13 is specified for clarity since just resize could be misleading. :)

they handle the passing of scale and size. This utility
provides the implementation for both
"""
ndims = len(infer_shape(inputs[0]))
mode = attr.get("mode").decode("ascii")
if mode == "nearest":
Expand All @@ -2715,16 +2749,6 @@ def _impl_v11(cls, inputs, attr, params):
alpha = attr.get("cubic_coeff_a", -0.75)
exclude = attr.get("exclude_outside", 0)

scale = inputs[2]
scale_shape = infer_shape(scale)
if len(inputs) == 4:
assert (
len(scale_shape) == 0 or scale_shape[0] == 0
), "One of scale or size should be passed, not both."
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
out = None
if ndims == 3:
Expand Down
9 changes: 1 addition & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3970,6 +3970,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex
make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales),
]
input_names = ["X", "roi", "scales"]

if oshape != []:
nodes.append(
make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape)
Expand Down Expand Up @@ -4954,15 +4955,7 @@ def verify_eyelike(indata):
"test_reduce_sum_keepdims_random",
"test_reduce_sum_negative_axes_keepdims_example",
"test_reduce_sum_negative_axes_keepdims_random",
"test_resize_downsample_sizes_cubic",
"test_resize_downsample_sizes_linear_pytorch_half_pixel",
"test_resize_downsample_sizes_nearest",
"test_resize_tf_crop_and_resize",
"test_resize_upsample_sizes_cubic",
"test_resize_upsample_sizes_nearest",
"test_resize_upsample_sizes_nearest_ceil_half_pixel",
"test_resize_upsample_sizes_nearest_floor_align_corners",
"test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
"test_rnn_seq_length",
"test_round",
"test_scan9_sum",
Expand Down