From 44331d8e8e87ba7b8a02028ac573fdbbec11bb49 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 17 Dec 2019 14:38:19 +0900 Subject: [PATCH 01/14] add onnx resize converter --- include/tvm/relay/attrs/image.h | 4 ++-- python/tvm/relay/frontend/onnx.py | 36 ++++++++++++++++++++++++++++- python/tvm/relay/op/image/_image.py | 4 ++-- python/tvm/relay/op/image/image.py | 4 ++-- src/relay/op/image/resize.cc | 4 ++-- topi/python/topi/image/resize.py | 26 +++++++++++++++------ 6 files changed, 62 insertions(+), 16 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index dd3a0aa0cc65..63567812be9c 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -36,7 +36,7 @@ struct ResizeAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; - bool align_corners; + std::string coordinate_transformation_mode; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -52,7 +52,7 @@ struct ResizeAttrs : public tvm::AttrsNode { "nearest_neighbor - Nearest Neighbor" "bilinear - Bilinear Interpolation" "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(true) + TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") .describe("Should be true to preserve the values at the corner pixels"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c7764db729ee..fa8b0f091cc1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1091,6 +1091,7 @@ class Or(Elemwise): def _impl_v7(cls, inputs, attr, params): return _op.logical_or(inputs[0], inputs[1]) + class Expand(OnnxOpConverter): """ Operator converter for Expand. """ @@ -1138,6 +1139,38 @@ def expand_shape(in_shape, shape): shape = expand_shape(in_shape, shape) return _op.broadcast_to(inputs[0], shape=tuple(shape)) + +class Resize(OnnxOpConverter): + @classmethod + def _impl_v11(cls, inputs, attr, params): + mode = attr.get('mode') + if mode == b'nearest': + method = "nearest_neighbor" + elif mode == b'linear': + method = "bilinear" + else: + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)) + scale = infer_value_simulated(inputs[2], params).asnumpy() + size = infer_value_simulated(inputs[3], params).asnumpy() + in_size = np.array(infer_shape(inputs[0])) + if len(scale) != 0: + assert len(size) == 0 + size = in_size * scale + else: + assert len(size) != 0 + coord_trans = attr.get('coordinate_transformation_mode') + if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: + coord_trans = "half_pixel" + elif coord_trans == b'align_corners': + coord_trans = "align_corners" + elif coord_trans == b'asymmetric' or method == "nearest_neighbor": + coord_trans = "asymmetric" + else: + raise tvm.error.OpAttributeInvalid( + 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) + return _op.image.resize(inputs[0], (size[-2], size[-1]), "NCHW", method, coord_trans) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1263,7 +1296,8 @@ def _get_convert_map(opset): 'Tile': Tile.get_converter(opset), 'Erf': Erf.get_converter(opset), 'Where': Where.get_converter(opset), - 'Or': Or.get_converter(opset) + 'Or': Or.get_converter(opset), + 'Resize': Resize.get_converter(opset), } diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index fcebfd8c9613..776435ada497 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -31,6 +31,6 @@ def compute_resize(attrs, inputs, out_type, target): size = attrs.size layout = attrs.layout method = attrs.method - align_corners = attrs.align_corners + coord_trans = attrs.coordinate_transformation_mode out_dtype = attrs.out_dtype - return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)] + return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index c54e438dce51..758e521e2bdf 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -22,7 +22,7 @@ def resize(data, size, layout="NCHW", method="bilinear", - align_corners=True, + coordinate_transformation_mode="half_pixel", out_dtype=None): """Image resize operator. @@ -59,4 +59,4 @@ def resize(data, result: relay.Expr The resized result. """ - return _make.resize(data, size, layout, method, align_corners, out_dtype) + return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index f6329f7af709..baab0ead692f 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -71,13 +71,13 @@ Expr MakeResize(Expr data, Array size, std::string layout, std::string method, - bool align_corners, + std::string coordinate_transformation_mode, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); - attrs->align_corners = align_corners; + attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 27bea9434348..82bd45ffe9b1 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -21,7 +21,8 @@ from .. import tag -def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): +def resize(data, size, layout="NCHW", method="bilinear", + coordinate_transformation_mode="half_pixel", out_dtype=None): """Perform resize operation on the data. Parameters @@ -66,12 +67,15 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out in_n, in_c, in_h, in_w, in_cc = data.shape output_shape = [in_n, in_c, size[0], size[1], in_cc] - if align_corners: + if coordinate_transformation_mode == "align_corners": y_ratio = (in_h - 1).astype('float') / (size[0] - 1) x_ratio = (in_w - 1).astype('float') / (size[1] - 1) - else: + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: y_ratio = (in_h).astype('float') / (size[0]) x_ratio = (in_w).astype('float') / (size[1]) + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) def _get_pixel(n, c, y, x, cc): y = tvm.max(tvm.min(y, in_h - 1), 0) @@ -127,8 +131,12 @@ def _lerp(A, B, t): def _bilinear(*indices): n, c, y, x, cc = _get_indices(*indices) - in_y = y_ratio * y - in_x = x_ratio * x + if coordinate_transformation_mode == "half_pixel": + in_y = y_ratio * (y + 0.5) - 0.5 + in_x = x_ratio * (x + 0.5) - 0.5 + else: + in_y = y_ratio * y + in_x = x_ratio * x xint = tvm.floor(in_x).astype('int32') xfract = in_x - tvm.floor(in_x) @@ -158,8 +166,12 @@ def _cubic_kernel(A, B, C, D, t): def _bicubic(*indices): n, c, y, x, cc = _get_indices(*indices) - in_y = y_ratio * y - in_x = x_ratio * x + if coordinate_transformation_mode == "half_pixel": + in_y = y_ratio * (y + 0.5) - 0.5 + in_x = x_ratio * (x + 0.5) - 0.5 + else: + in_y = y_ratio * y + in_x = x_ratio * x xint = tvm.floor(in_x).astype('int32') xfract = in_x - tvm.floor(in_x) From e0b0f5fbbbeda61da5a8e09671ca48b1a10fed4d Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 17 Dec 2019 15:08:40 +0900 Subject: [PATCH 02/14] update frontends --- python/tvm/relay/frontend/mxnet.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 2 +- python/tvm/relay/frontend/tflite.py | 3 ++- python/tvm/relay/op/image/image.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index a1a357883a83..1f85277712aa 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -676,7 +676,7 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, align_corners=True) + return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") def _mx_roi_pooling(inputs, attrs): new_attrs = {} diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index db037e49bded..c2c5b7710196 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -582,7 +582,7 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeUnImplemented( 'Attribute method=nearest is not supported') else: - attrs['align_corners'] = True + attrs['coordinate_transformation_mode'] = 'align_corners' attrs['method'] = 'bilinear' out = None diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index be18bf622196..8882d4079202 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -330,7 +330,8 @@ def _convert_resize(self, method, op): align_corners = resize_options.AlignCorners() # Use layout NHWC - out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners) + coord_trans = "align_corners" if align_corners else "asymmetric" + out = _op.image.resize(in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans) return out def convert_resize_bilinear(self, op): diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 758e521e2bdf..a35f6bcb6b05 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -48,8 +48,8 @@ def resize(data, method : str, optional Scale method to used [nearest_neighbor, bilinear, bicubic]. - align_corners : int, optional - Should be true to preserve the values at the corner pixels + coordinate_transformation_mode : string, optional + TODO out_dtype : str, optional Type to return. If left None returns the same type as input. From 0da175be400c157ac2d1476229986b5f4c1da39f Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 17 Dec 2019 16:10:07 +0900 Subject: [PATCH 03/14] updating topi --- include/tvm/relay/attrs/image.h | 2 +- python/tvm/relay/frontend/onnx.py | 7 ++++--- topi/python/topi/nn/upsampling.py | 3 ++- topi/tests/python/test_topi_resize.py | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 63567812be9c..b3e7a94b3ba8 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -53,7 +53,7 @@ struct ResizeAttrs : public tvm::AttrsNode { "bilinear - Bilinear Interpolation" "bicubic - Bicubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Should be true to preserve the values at the corner pixels"); + .describe("TODO"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type."); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fa8b0f091cc1..f8b9903a9545 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1155,10 +1155,10 @@ def _impl_v11(cls, inputs, attr, params): size = infer_value_simulated(inputs[3], params).asnumpy() in_size = np.array(infer_shape(inputs[0])) if len(scale) != 0: - assert len(size) == 0 + assert len(size) == 0, "One of scale or size should be passed, not both" size = in_size * scale else: - assert len(size) != 0 + assert len(size) != 0, "One of scale or size should be passed, not both" coord_trans = attr.get('coordinate_transformation_mode') if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: coord_trans = "half_pixel" @@ -1169,7 +1169,8 @@ def _impl_v11(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) - return _op.image.resize(inputs[0], (size[-2], size[-1]), "NCHW", method, coord_trans) + layout = "NHWC" if in_size[-1] == size[-1] else "NCHW" + return _op.image.resize(inputs[0], (size[-2], size[-1]), layout, method, coord_trans) # compatible operators that do NOT require any conversion. _identity_list = [] diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index fe63e474f2bf..c816bbb3c04e 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -61,8 +61,9 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', else: raise ValueError("not support this layout {} yet".format(layout)) + coord_trans = "align_corners" if align_corners else "asymmetric" return topi.image.resize(data, out_shape, layout=layout, - method=method, align_corners=align_corners) + method=method, coordinate_transformation_mode=coord_trans) def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 10678a0c2600..12a5065779f4 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -37,8 +37,8 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, else: raise NotImplementedError( 'Layout not supported {} '.format(layout)) - - B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) + coord_trans = "align_corners" if align_corners else "asymmetric" + B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method) if method == "bilinear": b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) From 68df7e7577b34c1adb28960182e085b58a11249c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 05:53:47 +0900 Subject: [PATCH 04/14] adding onnx resize tests --- python/tvm/relay/frontend/onnx.py | 2 +- tests/python/frontend/onnx/test_forward.py | 73 +++++++++++++++++----- topi/python/topi/image/resize.py | 6 +- 3 files changed, 60 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f8b9903a9545..0627f4a24326 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1156,7 +1156,7 @@ def _impl_v11(cls, inputs, attr, params): in_size = np.array(infer_shape(inputs[0])) if len(scale) != 0: assert len(size) == 0, "One of scale or size should be passed, not both" - size = in_size * scale + size = (in_size * scale).astype(np.int64) else: assert len(size) != 0, "One of scale or size should be passed, not both" coord_trans = attr.get('coordinate_transformation_mode') diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a35ebd23ae0a..fd0e03b381e9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -98,23 +98,6 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) -def verify_super_resolution_example(): - verify_onnx_forward_impl( - super_resolution, (1, 1, 224, 224), (1, 1, 672, 672)) - - -def verify_squeezenet1_1(): - verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000)) - - -def verify_lenet(): - verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10)) - - -def verify_resnet18(): - verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000)) - - def test_reshape(): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -1826,6 +1809,7 @@ def test_convtranspose(): verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) +<<<<<<< HEAD def test_unsqueeze_constant(): from torch.nn import Linear, Sequential, Module class Flatten(Module): @@ -1842,6 +1826,60 @@ def forward(self, input): onnx_model = onnx.load(file_name) relay.frontend.from_onnx(onnx_model, {'0': input_size}) +======= +def test_resize(): + def make_constant_node(name, data_type, dims, vals): + return helper.make_node('Constant', + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, + data_type=data_type, + dims=dims, + vals=vals)) + def verify(ishape, oshape, scales, mode, coord_trans): + roi_node = make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []) + scales_node = make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales) + sizes_node = make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape) + resize_node = helper.make_node( + 'Resize', + inputs=['X', 'roi', 'scales', 'sizes'], + outputs=['Y'], + mode=mode, + coordinate_transformation_mode=coord_trans + ) + + if oshape == []: + oshape = [dim * scale for (dim, scale) in zip(ishape, scales)] + + graph = helper.make_graph([roi_node, scales_node, sizes_node, resize_node], + "resize_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)]) + + model = helper.make_model(graph, producer_name='resize_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=ishape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) + onnx_out = get_onnxruntime_output(model, x, 'float32') + + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + + # NCHW and upsampling + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") + verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") + # NCHW and downsampling + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") + verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") + # NHWC and upsampling + # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "nearest", "asymmetric") + # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "linear", "align_corners") + # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "linear", "half_pixel") + # scales are specified instead of sizes + # verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") +>>>>>>> adding onnx resize tests if __name__ == '__main__': @@ -1901,3 +1939,4 @@ def forward(self, input): test_conv() test_convtranspose() test_unsqueeze_constant() + test_resize() diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 82bd45ffe9b1..1815ef4b56e5 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -38,8 +38,8 @@ def resize(data, size, layout="NCHW", method="bilinear", layout: string, optional "NCHW", "NHWC", or "NCHWc". - align_corners: Boolean, optional - To preserve the values at the corner pixels. + coordinate_transformation_mode: string, optional + TODO method: {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for resizing. @@ -113,7 +113,7 @@ def _nearest_neighbor(*indices): in_y = y_ratio * y in_x = x_ratio * x - if align_corners: + if coordinate_transformation_mode == "align_corners": yint = tvm.round(in_y).astype('int32') xint = tvm.round(in_x).astype('int32') else: From aed9569467da0b7e733c9f7f9efb09a4194cf2d0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 06:59:05 +0900 Subject: [PATCH 05/14] fixed NHWC test by casting size dtype to int32 --- python/tvm/relay/frontend/onnx.py | 22 +++++++++----- tests/python/frontend/onnx/test_forward.py | 35 +++++++++++----------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0627f4a24326..eddfd701ea49 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1151,14 +1151,16 @@ def _impl_v11(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)) - scale = infer_value_simulated(inputs[2], params).asnumpy() - size = infer_value_simulated(inputs[3], params).asnumpy() + in_size = np.array(infer_shape(inputs[0])) - if len(scale) != 0: - assert len(size) == 0, "One of scale or size should be passed, not both" - size = (in_size * scale).astype(np.int64) + scale = infer_value_simulated(inputs[2], params).asnumpy() + if len(inputs) == 4: + assert len(scale) == 0, "One of scale or size should be passed, not both." + size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32) else: - assert len(size) != 0, "One of scale or size should be passed, not both" + assert len(scale) != 0, "One of scale or size should be passed." + size = (in_size * scale).astype(np.int64) + coord_trans = attr.get('coordinate_transformation_mode') if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: coord_trans = "half_pixel" @@ -1169,8 +1171,12 @@ def _impl_v11(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) - layout = "NHWC" if in_size[-1] == size[-1] else "NCHW" - return _op.image.resize(inputs[0], (size[-2], size[-1]), layout, method, coord_trans) + layout = "NCHW" + out_size = (size[2], size[3]) + if in_size[-1] == size[-1]: + layout = "NHWC" + out_size = (size[1], size[2]) + return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) # compatible operators that do NOT require any conversion. _identity_list = [] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index fd0e03b381e9..ecb44c5e97f1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1809,7 +1809,6 @@ def test_convtranspose(): verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) -<<<<<<< HEAD def test_unsqueeze_constant(): from torch.nn import Linear, Sequential, Module class Flatten(Module): @@ -1826,7 +1825,8 @@ def forward(self, input): onnx_model = onnx.load(file_name) relay.frontend.from_onnx(onnx_model, {'0': input_size}) -======= + + def test_resize(): def make_constant_node(name, data_type, dims, vals): return helper.make_node('Constant', @@ -1837,21 +1837,26 @@ def make_constant_node(name, data_type, dims, vals): dims=dims, vals=vals)) def verify(ishape, oshape, scales, mode, coord_trans): - roi_node = make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []) - scales_node = make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales) - sizes_node = make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape) - resize_node = helper.make_node( + nodes = [ + make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []), + make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales) + ] + input_names = ['X', 'roi', 'scales'] + if oshape != []: + nodes.append(make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape)) + input_names.append('sizes') + nodes.append(helper.make_node( 'Resize', - inputs=['X', 'roi', 'scales', 'sizes'], + inputs=input_names, outputs=['Y'], mode=mode, coordinate_transformation_mode=coord_trans - ) + )) if oshape == []: oshape = [dim * scale for (dim, scale) in zip(ishape, scales)] - graph = helper.make_graph([roi_node, scales_node, sizes_node, resize_node], + graph = helper.make_graph(nodes, "resize_test", inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)], outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)]) @@ -1860,12 +1865,12 @@ def verify(ishape, oshape, scales, mode, coord_trans): for target, ctx in ctx_list(): x = np.random.uniform(size=ishape).astype('float32') - tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) onnx_out = get_onnxruntime_output(model, x, 'float32') + tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) - # NCHW and upsampling + # # NCHW and upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") @@ -1873,13 +1878,9 @@ def verify(ishape, oshape, scales, mode, coord_trans): verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") - # NHWC and upsampling - # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "nearest", "asymmetric") - # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "linear", "align_corners") - # verify([1, 32, 32, 16], [1, 64, 64, 16], [], "linear", "half_pixel") # scales are specified instead of sizes - # verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") ->>>>>>> adding onnx resize tests + verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") + verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") if __name__ == '__main__': From 48594322db4d499fd3748e04e7c724a5475c703f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 07:32:46 +0900 Subject: [PATCH 06/14] fix tests --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index ecb44c5e97f1..d915acb439c6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1854,7 +1854,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): )) if oshape == []: - oshape = [dim * scale for (dim, scale) in zip(ishape, scales)] + oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] graph = helper.make_graph(nodes, "resize_test", From 1f066f064a535423078364f208b1ecefc3067f38 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 07:50:55 +0900 Subject: [PATCH 07/14] fix lint --- python/tvm/relay/frontend/onnx.py | 2 ++ python/tvm/relay/frontend/tflite.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index eddfd701ea49..f6bdcad90d17 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1141,6 +1141,8 @@ def expand_shape(in_shape, shape): class Resize(OnnxOpConverter): + """Operator converter for Resize + """ @classmethod def _impl_v11(cls, inputs, attr, params): mode = attr.get('mode') diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 8882d4079202..e2e01e545340 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -331,7 +331,8 @@ def _convert_resize(self, method, op): # Use layout NHWC coord_trans = "align_corners" if align_corners else "asymmetric" - out = _op.image.resize(in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans) + out = _op.image.resize(in_expr, target_size, "NHWC", method, + coordinate_transformation_mode=coord_trans) return out def convert_resize_bilinear(self, op): From 607622fca257e0be7905732458c182efac1f8f1c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 08:19:15 +0900 Subject: [PATCH 08/14] update existing test cases --- python/tvm/relay/frontend/onnx.py | 2 +- tests/python/relay/test_op_level5.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f6bdcad90d17..1427ecb5d972 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1161,7 +1161,7 @@ def _impl_v11(cls, inputs, attr, params): size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32) else: assert len(scale) != 0, "One of scale or size should be passed." - size = (in_size * scale).astype(np.int64) + size = (in_size * scale).astype(np.int32) coord_trans = attr.get('coordinate_transformation_mode') if coord_trans in [b'pytorch_half_pixel', b'half_pixel']: diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 84e9f55d67e7..2f2e8523161c 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -39,7 +39,7 @@ def test_resize_infer_type(): assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True) + z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @@ -57,7 +57,7 @@ def verify_resize(dshape, scale, method, layout): else: ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, True) + z = relay.image.resize(x, size, layout, method, "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") From a5d5e564576a910eb9741bcfbb756e33c611063d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 10:03:46 +0900 Subject: [PATCH 09/14] fix tensorflow frontend --- python/tvm/relay/frontend/tensorflow.py | 4 ++++ topi/python/topi/testing/bilinear_resize_python.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c2c5b7710196..c430ce38996a 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -632,6 +632,10 @@ def _impl(inputs, attr, params): inputs.pop(1) # NHWC attr['layout'] = 'NHWC' + if attr.pop('align_corners') == True: + attr['coordinate_transformation_mode'] = 'align_corners' + else: + attr['coordinate_transformation_mode'] = 'asymmetric' # Ignore the new attributes from TF2.0, for now. return AttrCvt(op_name='resize', diff --git a/topi/python/topi/testing/bilinear_resize_python.py b/topi/python/topi/testing/bilinear_resize_python.py index 86dd450a88e2..f15c0d12accd 100644 --- a/topi/python/topi/testing/bilinear_resize_python.py +++ b/topi/python/topi/testing/bilinear_resize_python.py @@ -19,7 +19,7 @@ import math import numpy as np -def bilinear_resize_python(image, out_size, layout, align_corners=True): +def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): """ Bilinear scaling using python""" (new_h, new_w) = out_size @@ -30,7 +30,7 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): (batch, channel, h, w) = image.shape scaled_image = np.ones((batch, channel, new_h, new_w)) - if align_corners: + if coordinate_transformation_mode == "align_corners": height_scale = np.float32(h-1) / np.float32(out_size[0]-1) width_scale = np.float32(w-1) / np.float32(out_size[1]-1) else: @@ -41,7 +41,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): for i in range(channel): for j in range(new_h): for k in range(new_w): - in_y = j * height_scale + if coordinate_transformation_mode == "half_pixel": + in_y = (j + 0.5) * height_scale - 0.5 + else: + in_y = j * height_scale y0 = math.floor(in_y) y1 = min(math.ceil(in_y), h - 1) y_lerp = in_y - y0 @@ -49,7 +52,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): y0 = int(y0) y1 = int(y1) - in_x = k * width_scale + if coordinate_transformation_mode == "half_pixel": + in_x = (k + 0.5) * width_scale - 0.5 + else: + in_x = k * width_scale x0 = math.floor(in_x) x1 = min(math.ceil(in_x), w - 1) x_lerp = in_x - x0 From a28db326c0ae45365338560e8a1347f1030243ea Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 10:36:59 +0900 Subject: [PATCH 10/14] fix lint --- python/tvm/relay/frontend/tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c430ce38996a..8a6e5b778283 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -632,7 +632,7 @@ def _impl(inputs, attr, params): inputs.pop(1) # NHWC attr['layout'] = 'NHWC' - if attr.pop('align_corners') == True: + if attr.pop('align_corners') is True: attr['coordinate_transformation_mode'] = 'align_corners' else: attr['coordinate_transformation_mode'] = 'asymmetric' From 073464a9b8470ca4e26516dc54d422a2d84ce331 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 11:27:24 +0900 Subject: [PATCH 11/14] remove NHWC stuff --- python/tvm/relay/frontend/onnx.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1427ecb5d972..4809100f3c2c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1173,11 +1173,8 @@ def _impl_v11(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Unsupported coordinate_transformation_mode: {}'.format(coord_trans)) - layout = "NCHW" + layout = "NCHW" # ONNX assumes NCHW layout out_size = (size[2], size[3]) - if in_size[-1] == size[-1]: - layout = "NHWC" - out_size = (size[1], size[2]) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) # compatible operators that do NOT require any conversion. From 4a18becf5ed62a30639103d2f30af14003780613 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 12:15:19 +0900 Subject: [PATCH 12/14] update topi resize test for half_pixel --- .../topi/testing/bilinear_resize_python.py | 29 +++++++++---------- topi/tests/python/test_topi_resize.py | 18 +++++++----- topi/tests/python/test_topi_upsampling.py | 2 +- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/topi/python/topi/testing/bilinear_resize_python.py b/topi/python/topi/testing/bilinear_resize_python.py index f15c0d12accd..d324e2900c4f 100644 --- a/topi/python/topi/testing/bilinear_resize_python.py +++ b/topi/python/topi/testing/bilinear_resize_python.py @@ -37,6 +37,9 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo height_scale = np.float32(h) / np.float32(out_size[0]) width_scale = np.float32(w) / np.float32(out_size[1]) + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + for b in range(batch): for i in range(channel): for j in range(new_h): @@ -45,23 +48,19 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo in_y = (j + 0.5) * height_scale - 0.5 else: in_y = j * height_scale - y0 = math.floor(in_y) - y1 = min(math.ceil(in_y), h - 1) - y_lerp = in_y - y0 - - y0 = int(y0) - y1 = int(y1) + y0 = int(math.floor(in_y)) + y1 = max(min(y0 + 1, h - 1), 0) + y0 = max(y0, 0) + y_lerp = in_y - math.floor(in_y) if coordinate_transformation_mode == "half_pixel": in_x = (k + 0.5) * width_scale - 0.5 else: in_x = k * width_scale - x0 = math.floor(in_x) - x1 = min(math.ceil(in_x), w - 1) - x_lerp = in_x - x0 - - x0 = int(x0) - x1 = int(x1) + x0 = int(math.floor(in_x)) + x1 = max(min(x0 + 1, w - 1), 0) + x0 = max(x0, 0) + x_lerp = in_x - math.floor(in_x) if layout == 'NHWC': A = image[b][y0][x0][i] @@ -74,10 +73,10 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo C = image[b][i][y1][x0] D = image[b][i][y1][x1] - top = A + (B - A) * x_lerp - bottom = C + (D - C) * x_lerp + top = _lerp(A, B, x_lerp) + bottom = _lerp(C, D, x_lerp) - pixel = np.float32(top + (bottom - top) * y_lerp) + pixel = np.float32(_lerp(top, bottom, y_lerp)) if layout == 'NHWC': scaled_image[b][j][k][i] = pixel diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 12a5065779f4..206903ff1dc1 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -23,7 +23,8 @@ from common import get_all_backend -def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"): +def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, + layout='NCHW', coord_trans="align_corners", method="bilinear"): if layout == 'NCHW': A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') dtype = A.dtype @@ -37,11 +38,9 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, else: raise NotImplementedError( 'Layout not supported {} '.format(layout)) - coord_trans = "align_corners" if align_corners else "asymmetric" B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method) - if method == "bilinear": - b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) + b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, coord_trans) else: scale_h = out_height / in_height scale_w = out_width / in_width @@ -70,14 +69,17 @@ def test_resize(): # Scale NCHW verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') # Scale NCHW + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True) + verify_resize(6, 32, 64, 64, 20, 20, 'NCHW') # Scale NHWC verify_resize(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) + verify_resize(6, 32, 64, 64, 20, 20, "NHWC") # Nearest + Fractional - verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False) - verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False) + verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="nearest_neighbor") + verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="nearest_neighbor") + # half_pixel + verify_resize(4, 16, 16, 16, 32, 32, 'NCHW', "half_pixel", method="bilinear") + verify_resize(4, 16, 16, 16, 32, 32, 'NHWC', "half_pixel", method="bilinear") def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width, diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index f5b77b1190a6..3aa67a5f78a4 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -43,7 +43,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, if method == "bilinear": out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w))) - b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) + b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric") else: b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) From a94c141cd9e6e7d710c2797203af475b6fee9bf8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 12:41:26 +0900 Subject: [PATCH 13/14] update doc --- include/tvm/relay/attrs/image.h | 5 ++++- python/tvm/relay/op/image/image.py | 5 ++++- topi/python/topi/image/resize.py | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index b3e7a94b3ba8..a2757bb79ae2 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -53,7 +53,10 @@ struct ResizeAttrs : public tvm::AttrsNode { "bilinear - Bilinear Interpolation" "bicubic - Bicubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("TODO"); + .describe("Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details". + "Available options are half_pixel, align_corners and asymmetric"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type."); diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index a35f6bcb6b05..e0475a06025a 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -49,7 +49,10 @@ def resize(data, Scale method to used [nearest_neighbor, bilinear, bicubic]. coordinate_transformation_mode : string, optional - TODO + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] out_dtype : str, optional Type to return. If left None returns the same type as input. diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 1815ef4b56e5..004e04a604e5 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -39,7 +39,10 @@ def resize(data, size, layout="NCHW", method="bilinear", "NCHW", "NHWC", or "NCHWc". coordinate_transformation_mode: string, optional - TODO + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". method: {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for resizing. From 4d238b9ee374201e161acce8f79a61b8066875d4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 Dec 2019 12:56:25 +0900 Subject: [PATCH 14/14] fix doc --- include/tvm/relay/attrs/image.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index a2757bb79ae2..87ad82d0293f 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -55,7 +55,7 @@ struct ResizeAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") .describe("Describes how to transform the coordinate in the resized tensor" "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details". + "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue())