From 8b8b16b33afe1f2338b513efebeaa1fa626ace11 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 4 Dec 2020 09:20:10 -0800 Subject: [PATCH 1/4] Added maxunpool test. --- tests/python/frontend/onnx/test_forward.py | 198 +++++++++++++-------- 1 file changed, 122 insertions(+), 76 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 3ddc80af3a32..a93c76858393 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3915,80 +3915,126 @@ def verify_size(indata): verify_size(input_data) +@tvm.testing.uses_gpu +def test_maxunpool(): + def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, strides=None): + input_names = ["xT", "xI"] + input_info = [helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)), + helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape))] + input_values = [data, indices] + if output_shape is not None: + input_names.append("output_shape") + input_info.append(helper.make_tensor_value_info("output_shape", TensorProto.INT64, list(output_shape.shape))) + input_values.append(output_shape) + + node = helper.make_node( + "MaxUnpool", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape + ) + + if pads is not None: + pad_attr = helper.make_attribute('pads', pads) + node.attribute.append(pad_attr) + + if strides is not None: + strides_attr = helper.make_attribute('strides', strides) + node.attribute.append(strides_attr) + + graph = helper.make_graph( + [node], + "maxunpool_test", + inputs=input_info, + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + ) + + model = helper.make_model(graph, producer_name="size_test") + + verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11) + + xT = np.array([[[[5, 6], + [7, 8]]]], dtype=np.float32) + xI = np.array([[[[5, 7], + [13, 15]]]], dtype=np.int64) + verify_maxunpool(xT, xI, [2, 2], strides=[2, 2]) + + if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_single_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_ScaledTanh() - test_ParametricSoftplus() - test_Scale() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() - test_loop() - test_size() + test_maxunpool() + #test_flatten() + #test_reshape() + #test_shape() + #test_expand() + #test_power() + #test_squeeze() + #test_unsqueeze() + #test_slice() + #test_floor() + #test_ceil() + #test_round() + #test_isinf() + #test_isnan() + #test_clip() + #test_clip_min_max_as_inputs() + #test_onehot() + #test_matmul() + #test_gather() + #test_gatherelements() + #test_gather_nd() + #test_scatter() + #test_lrn() + #test_instance_norm() + #test_upsample() + #test_forward_min() + #test_forward_max() + #test_forward_mean() + #test_forward_hardsigmoid() + #test_forward_arg_min_max() + #test_softmax() + #test_constantofshape() + #test_all_reduce_funcs() + #test_pad() + #test_split() + #test_binary_ops() + #test_single_ops() + #test_leaky_relu() + #test_elu() + #test_selu() + #test_prelu() + #test_ThresholdedRelu() + #test_ScaledTanh() + #test_ParametricSoftplus() + #test_Scale() + #test_LogSoftmax() + #test_resnet() + #test_inception() + #test_densenet() + #test_sign() + #test_not() + #test_and() + #test_tile() + #test_erf() + #test_where() + #test_or() + #test_depth_to_space() + #test_space_to_depth() + #test_batch_norm() + #test_batch_norm_dynamic_subgraph() + #test_conv() + #test_convtranspose() + #test_unsqueeze_constant() + #test_pooling() + #test_lppool() + #test_lstm() + #test_gru() + #test_resize() + #test_nonzero() + #test_topk() + #test_mod() + #test_xor() + #test_max_roi_pool() + #test_roi_align() + #test_range() + #test_loop() + #test_size() From 04d8037d0a4270bb327720791a481e08631c31b8 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 4 Dec 2020 10:57:06 -0800 Subject: [PATCH 2/4] MaxUnpool implemented and tested. --- python/tvm/relay/frontend/onnx.py | 61 +++++++ tests/python/frontend/onnx/test_forward.py | 190 +++++++++++---------- 2 files changed, 162 insertions(+), 89 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d65f5676fb33..00a1d8ba96a7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -622,6 +622,66 @@ class MaxPool(Pool): name = "max_pool" +class MaxUnpool(OnnxOpConverter): + """Operator converter for MaxUnpool""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + # Unpack inputs and attributes + data = inputs[0] + data_type = infer_type(data).checked_type.dtype + indices = inputs[1] + output_shape = inputs[2] + kernel_shape = attr.get("kernel_shape") + pads = attr.get("pads", None) + strides = attr.get("strides", [1] * len(kernel_shape)) + + # Compute the proper output shape before padding. + multiplier = _op.concatenate( + [_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0 + ) + total_output_shape = multiplier * _op.shape_of(data, dtype="int64") + # Add extra dimensions from kernel size and stride mismatch + total_output_shape += _op.concatenate( + [_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0 + ) - _op.concatenate( + [_expr.const([0, 0], "int64"), _expr.const(list(strides), "int64")], axis=0 + ) + + # Compute padding amount if output shape is specified. Note that this ignores the pads attribute. + if output_shape is not None: + # Compute total extra values to add + total_pad = _op.maximum( + _expr.const(0, dtype="int64"), output_shape - total_output_shape + ) + total_output_shape = total_output_shape + total_pad + + elif pads is not None: + # Get pads in the proper format for relay. + pads = _op.concatenate( + [_expr.const([0, 0, 0, 0], "int64"), _expr.const(list(pads), "int64")], axis=0 + ) + pads = _op.reshape(pads, [-1, 2]) + # Compute the total padding per axis. + total_pad = _op.sum(pads, axis=-1) + # Reversing maxpool means that padding actually makes our output smaller. + total_output_shape = total_output_shape - total_pad + + # Create a tensor of zeros then scatter our data through it. + zeros_tensor = _op.zeros(total_output_shape, data_type) + # We need to flatten all our tensors before scattering. + flat_tensor = _op.scatter( + _op.reshape(zeros_tensor, [-1]), + _op.reshape(indices, [-1]), + _op.reshape(data, [-1]), + axis=0, + ) + # Now reshape back to prepadded shape. + output_tensor = _op.reshape(flat_tensor, total_output_shape) + + return output_tensor + + class LpPool(OnnxOpConverter): """A helper class for lppool op converters.""" @@ -2330,6 +2390,7 @@ def _get_convert_map(opset): "AveragePool": AveragePool.get_converter(opset), "LpPool": LpPool.get_converter(opset), "MaxPool": MaxPool.get_converter(opset), + "MaxUnpool": MaxUnpool.get_converter(opset), "Conv": Conv.get_converter(opset), "ConvTranspose": ConvTranspose.get_converter(opset), "GlobalAveragePool": Renamer("global_avg_pool2d"), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a93c76858393..626eba945f2e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3919,27 +3919,30 @@ def verify_size(indata): def test_maxunpool(): def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, strides=None): input_names = ["xT", "xI"] - input_info = [helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)), - helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape))] + input_info = [ + helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)), + helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape)), + ] input_values = [data, indices] if output_shape is not None: input_names.append("output_shape") - input_info.append(helper.make_tensor_value_info("output_shape", TensorProto.INT64, list(output_shape.shape))) + input_info.append( + helper.make_tensor_value_info( + "output_shape", TensorProto.INT64, list(output_shape.shape) + ) + ) input_values.append(output_shape) node = helper.make_node( - "MaxUnpool", - inputs=input_names, - outputs=["y"], - kernel_shape=kernel_shape + "MaxUnpool", inputs=input_names, outputs=["y"], kernel_shape=kernel_shape ) if pads is not None: - pad_attr = helper.make_attribute('pads', pads) + pad_attr = helper.make_attribute("pads", pads) node.attribute.append(pad_attr) if strides is not None: - strides_attr = helper.make_attribute('strides', strides) + strides_attr = helper.make_attribute("strides", strides) node.attribute.append(strides_attr) graph = helper.make_graph( @@ -3953,88 +3956,97 @@ def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11) - xT = np.array([[[[5, 6], - [7, 8]]]], dtype=np.float32) - xI = np.array([[[[5, 7], - [13, 15]]]], dtype=np.int64) + # Basic test + xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32) + xI = np.array([[[[0, 7], [13, 15]]]], dtype=np.int64) verify_maxunpool(xT, xI, [2, 2], strides=[2, 2]) + # Small stride + verify_maxunpool(xT, xI, [2, 2], strides=[1, 1]) + # Big kernel + verify_maxunpool(xT, xI, [3, 3], strides=[2, 2]) + # With output shape + output_shape = np.array((1, 1, 5, 5), dtype=np.int64) + verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], output_shape=output_shape) + # With explicit reverse padding + pads = np.asarray([1, 1, 1, 1]).astype(np.int64) + verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads) if __name__ == "__main__": + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_single_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_ScaledTanh() + test_ParametricSoftplus() + test_Scale() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() + test_conv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() + test_range() + test_loop() + test_size() test_maxunpool() - #test_flatten() - #test_reshape() - #test_shape() - #test_expand() - #test_power() - #test_squeeze() - #test_unsqueeze() - #test_slice() - #test_floor() - #test_ceil() - #test_round() - #test_isinf() - #test_isnan() - #test_clip() - #test_clip_min_max_as_inputs() - #test_onehot() - #test_matmul() - #test_gather() - #test_gatherelements() - #test_gather_nd() - #test_scatter() - #test_lrn() - #test_instance_norm() - #test_upsample() - #test_forward_min() - #test_forward_max() - #test_forward_mean() - #test_forward_hardsigmoid() - #test_forward_arg_min_max() - #test_softmax() - #test_constantofshape() - #test_all_reduce_funcs() - #test_pad() - #test_split() - #test_binary_ops() - #test_single_ops() - #test_leaky_relu() - #test_elu() - #test_selu() - #test_prelu() - #test_ThresholdedRelu() - #test_ScaledTanh() - #test_ParametricSoftplus() - #test_Scale() - #test_LogSoftmax() - #test_resnet() - #test_inception() - #test_densenet() - #test_sign() - #test_not() - #test_and() - #test_tile() - #test_erf() - #test_where() - #test_or() - #test_depth_to_space() - #test_space_to_depth() - #test_batch_norm() - #test_batch_norm_dynamic_subgraph() - #test_conv() - #test_convtranspose() - #test_unsqueeze_constant() - #test_pooling() - #test_lppool() - #test_lstm() - #test_gru() - #test_resize() - #test_nonzero() - #test_topk() - #test_mod() - #test_xor() - #test_max_roi_pool() - #test_roi_align() - #test_range() - #test_loop() - #test_size() From d8120abb68022f2ed873677fd2ac1af0044219a1 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 4 Dec 2020 11:39:24 -0800 Subject: [PATCH 3/4] Lint fix. --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 00a1d8ba96a7..cb8ae5762ad0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -648,7 +648,7 @@ def _impl_v11(cls, inputs, attr, params): [_expr.const([0, 0], "int64"), _expr.const(list(strides), "int64")], axis=0 ) - # Compute padding amount if output shape is specified. Note that this ignores the pads attribute. + # Compute padding amount if output shape is specified. if output_shape is not None: # Compute total extra values to add total_pad = _op.maximum( From 55cc4af4f327379bebc28581db1e8a78f82156d8 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 7 Dec 2020 17:09:30 -0800 Subject: [PATCH 4/4] Add explicit output shape in tests. --- python/tvm/relay/frontend/onnx.py | 6 +----- tests/python/frontend/onnx/test_forward.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cb8ae5762ad0..0b6ebdb5d5c2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -650,11 +650,7 @@ def _impl_v11(cls, inputs, attr, params): # Compute padding amount if output shape is specified. if output_shape is not None: - # Compute total extra values to add - total_pad = _op.maximum( - _expr.const(0, dtype="int64"), output_shape - total_output_shape - ) - total_output_shape = total_output_shape + total_pad + total_output_shape = output_shape elif pads is not None: # Get pads in the proper format for relay. diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 626eba945f2e..1e0b729cbef0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3917,7 +3917,7 @@ def verify_size(indata): @tvm.testing.uses_gpu def test_maxunpool(): - def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, strides=None): + def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pads=None): input_names = ["xT", "xI"] input_info = [ helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)), @@ -3932,6 +3932,17 @@ def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, ) ) input_values.append(output_shape) + else: + # Compute expected output shape + output_shape = np.asarray(([1, 1] + list(strides))) * np.asarray(list(data.shape)) + output_shape += np.asarray(([0, 0] + list(kernel_shape))) - np.asarray( + ([0, 0] + list(strides)) + ) + if pads is not None: + output_shape -= np.asarray( + [0, 0] + list(np.sum(np.reshape(list(pads), [-1, 2]), axis=-1)) + ) + output_shape = [int(i) for i in output_shape] node = helper.make_node( "MaxUnpool", inputs=input_names, outputs=["y"], kernel_shape=kernel_shape @@ -3949,7 +3960,7 @@ def verify_maxunpool(data, indices, kernel_shape, output_shape=None, pads=None, [node], "maxunpool_test", inputs=input_info, - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)], ) model = helper.make_model(graph, producer_name="size_test")