diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index fdf6bd77ee46..e9612ef4386b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1395,6 +1395,44 @@ def _impl_v7(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)
 
 
+class Resize(OnnxOpConverter):
+    """Operator converter for Resize
+    """
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        mode = attr.get('mode')
+        if mode == b'nearest':
+            method = "nearest_neighbor"
+        elif mode == b'linear':
+            method = "bilinear"
+        else:
+            raise tvm.error.OpAttributeInvalid(
+                'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode))
+
+        in_size = np.array(infer_shape(inputs[0]))
+        scale = infer_value_simulated(inputs[2], params).asnumpy()
+        if len(inputs) == 4:
+            assert len(scale) == 0, "One of scale or size should be passed, not both."
+            size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32)
+        else:
+            assert len(scale) != 0, "One of scale or size should be passed."
+            size = (in_size * scale).astype(np.int32)
+
+        coord_trans = attr.get('coordinate_transformation_mode')
+        if coord_trans in [b'pytorch_half_pixel', b'half_pixel']:
+            coord_trans = "half_pixel"
+        elif coord_trans == b'align_corners':
+            coord_trans = "align_corners"
+        elif coord_trans == b'asymmetric' or method == "nearest_neighbor":
+            coord_trans = "asymmetric"
+        else:
+            raise tvm.error.OpAttributeInvalid(
+                'Unsupported coordinate_transformation_mode: {}'.format(coord_trans))
+        layout = "NCHW"  # ONNX assumes NCHW layout
+        out_size = (size[2], size[3])
+        return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1523,6 +1561,7 @@ def _get_convert_map(opset):
         'Erf': Erf.get_converter(opset),
         'Where': Where.get_converter(opset),
         'Or': Or.get_converter(opset),
+        'Resize': Resize.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index ef96c11c89ce..6243178dcb2b 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -2137,6 +2137,63 @@ def test_lstm():
                use_peep=True)
 
 
+def test_resize():
+    def make_constant_node(name, data_type, dims, vals):
+        return helper.make_node('Constant',
+                                inputs=[],
+                                outputs=[name],
+                                value=helper.make_tensor(name=name,
+                                                         data_type=data_type,
+                                                         dims=dims,
+                                                         vals=vals))
+
+    def verify(ishape, oshape, scales, mode, coord_trans):
+        nodes = [
+            make_constant_node('roi', onnx.TensorProto.FLOAT, (0,), []),
+            make_constant_node('scales', onnx.TensorProto.FLOAT, (len(scales),), scales)
+        ]
+        input_names = ['X', 'roi', 'scales']
+        if oshape != []:
+            nodes.append(make_constant_node('sizes', onnx.TensorProto.INT64, (len(oshape),), oshape))
+            input_names.append('sizes')
+        nodes.append(helper.make_node(
+            'Resize',
+            inputs=input_names,
+            outputs=['Y'],
+            mode=mode,
+            coordinate_transformation_mode=coord_trans
+        ))
+
+        if oshape == []:
+            oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)]
+
+        graph = helper.make_graph(nodes,
+                                  "resize_test",
+                                  inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)],
+                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)])
+
+        model = helper.make_model(graph, producer_name='resize_test')
+
+        for target, ctx in ctx_list():
+            x = np.random.uniform(size=ishape).astype('float32')
+            onnx_out = get_onnxruntime_output(model, x, 'float32')
+            tvm_out = get_tvm_output(model, x, target, ctx, oshape, 'float32', opset=11)
+
+            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
+
+    # upsampling
+    verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric")
+    verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners")
+    verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel")
+    # downsampling
+    verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric")
+    verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners")
+    verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel")
+    # scales are specified instead of sizes
+    verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric")
+    verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -2196,3 +2253,4 @@ def test_lstm():
     test_unsqueeze_constant()
     test_pooling()
     test_lstm()
+    test_resize()
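
For reviewers, a minimal usage sketch (not part of the patch) showing how the new Resize converter is reached through relay.frontend.from_onnx. It mirrors the nearest/asymmetric upsampling case from test_resize; the tensor names, shapes, scales, and the graph name 'resize_example' are illustrative assumptions, not anything the patch requires.

# Standalone sketch: build an ONNX graph with a single opset-11 Resize node and
# import it with the Relay ONNX frontend changed above. Assumes onnx and tvm are
# installed; names and shapes here are illustrative only.
from onnx import helper, TensorProto

from tvm import relay

nodes = [
    # roi and scales are fed as Constant nodes, matching how test_resize builds them.
    helper.make_node('Constant', inputs=[], outputs=['roi'],
                     value=helper.make_tensor('roi', TensorProto.FLOAT, (0,), [])),
    helper.make_node('Constant', inputs=[], outputs=['scales'],
                     value=helper.make_tensor('scales', TensorProto.FLOAT, (4,),
                                              [1.0, 1.0, 2.0, 2.0])),
    helper.make_node('Resize', inputs=['X', 'roi', 'scales'], outputs=['Y'],
                     mode='nearest', coordinate_transformation_mode='asymmetric'),
]
graph = helper.make_graph(
    nodes, 'resize_example',
    inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 16, 32, 32])],
    outputs=[helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 16, 64, 64])])
model = helper.make_model(graph, producer_name='resize_example')

# The converter is registered as _impl_v11, so request opset 11 explicitly,
# just as the test does via get_tvm_output(..., opset=11).
mod, params = relay.frontend.from_onnx(model, shape={'X': (1, 16, 32, 32)}, opset=11)
print(mod)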