From 1ef667b1428fb49a1b3dab8846c44b813f79cb96 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 24 Jun 2021 00:53:11 +0000 Subject: [PATCH 1/2] Enable group conv1d import through conv2d hack. --- python/tvm/relay/frontend/onnx.py | 28 +++- tests/python/frontend/onnx/test_forward.py | 182 +++++++++++---------- 2 files changed, 125 insertions(+), 85 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c8855b2ea2be..5d07102f2c3f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -451,6 +451,7 @@ class Conv(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. data = inputs[0] + kernel = inputs[1] input_shape = infer_shape(data) ndim = len(input_shape) @@ -473,13 +474,32 @@ def _impl_v1(cls, inputs, attr, params): mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": - attr["pads"] = tuple([0 for i in range(ndim - 2)]) + attr["pads"] = [0 for i in range(ndim - 2)] elif attr["auto_pad"] == "NOTSET": pass else: msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) attr.pop("auto_pad") + + # Check if the requested convolution is a group conv1d, if so convert it to conv2d. + # TODO(jwfromm) Remove once proper group_conv1d is supported. + group_conv1d = False + if dimension_picker("conv")(attr) == "conv1d" and attr.get("group") != 1: + group_conv1d = True + # Expand input from NCW to NCHW + data = _op.expand_dims(data, axis=2) + # Expand kernel from OIW to OIHW + kernel = _op.expand_dims(kernel, axis=2) + # Add new value to kernel_shape, strices, dilation, pads, if needed + attr["kernel_shape"] = [1] + list(attr["kernel_shape"]) + if "strides" in attr: + attr["strides"] = [1] + list(attr["strides"]) + if "dilations" in attr: + attr["dilations"] = [1] + list(attr["dilations"]) + if "pads" in attr: + attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]] + out = AttrCvt( op_name=dimension_picker("conv"), transforms={ @@ -489,7 +509,11 @@ def _impl_v1(cls, inputs, attr, params): "group": ("groups", 1), }, custom_check=dimension_constraint(), - )([data, inputs[1]], attr, params) + )([data, kernel], attr, params) + + # If this was a group_conv1d, squish output back to NCW. + if group_conv1d: + out = _op.squeeze(out, axis=[2]) use_bias = len(inputs) == 3 if use_bias: diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6ac747c5ea94..aa18fe73551d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2410,6 +2410,7 @@ def verify_conv( kernel_shape, strides, dilations, + group=1, auto_pad="NOTSET", unset_pad=False, ): @@ -2422,7 +2423,7 @@ def verify_conv( # Default values for other attributes: strides=strides, dilations=dilations, - # groups=1 + group=group, ) elif padding is None: ## autopadding with unset default attributes @@ -2438,6 +2439,7 @@ def verify_conv( outputs=["y"], # Default values for other attributes: auto_pad=auto_pad, + group=group, **kwargs, ) else: @@ -2449,7 +2451,7 @@ def verify_conv( # Default values for other attributes: strides=strides, dilations=dilations, - # groups=1 + group=group, pads=padding, ) @@ -2559,6 +2561,20 @@ def repeat(N, D): repeat(2, D), ) + # TODO(jwfromm): Merge with other tests once group_conv3d is supported. + for D in [1, 2]: + # Group Convolution + verify_conv( + (1, 8) + repeat(5, D), + (8, 1) + repeat(3, D), + (1, 8) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + group=8, + ) + def verify_convtranspose_with_padding( x_shape, @@ -4641,85 +4657,85 @@ def repeat(N, D): if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_gemm() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() + # test_flatten() + # test_reshape() + # test_shape() + # test_expand() + # test_power() + # test_squeeze() + # test_unsqueeze() + # test_slice() + # test_floor() + # test_ceil() + # test_round() + # test_isinf() + # test_isnan() + # test_clip() + # test_clip_min_max_as_inputs() + # test_onehot() + # test_gemm() + # test_matmul() + # test_gather() + # test_gatherelements() + # test_gather_nd() + # test_scatter() + # test_lrn() + # test_instance_norm() + # test_upsample() + # test_forward_min() + # test_forward_max() + # test_forward_mean() + # test_forward_hardsigmoid() + # test_forward_arg_min_max() + # test_softmax() + # test_constantofshape() + # test_all_reduce_funcs() + # test_pad() + # test_split() + # test_binary_ops() + # test_unary_ops() + # test_leaky_relu() + # test_elu() + # test_selu() + # test_prelu() + # test_ThresholdedRelu() + # test_LogSoftmax() + # test_resnet() + # test_inception() + # test_densenet() + # test_sign() + # test_not() + # test_and() + # test_tile() + # test_erf() + # test_where() + # test_or() + # test_depth_to_space() + # test_space_to_depth() + # test_batch_norm() + # test_batch_norm_dynamic_subgraph() test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() - test_loop() - test_size() - test_maxunpool() - test_softplus() - test_cumsum() - test_wrong_input() - test_aten() - test_reverse_sequence() - test_eyelike() - test_qlinearconv() + # test_convtranspose() + # test_unsqueeze_constant() + # test_pooling() + # test_lppool() + # test_lstm() + # test_gru() + # test_resize() + # test_nonzero() + # test_topk() + # test_mod() + # test_xor() + # test_max_roi_pool() + # test_roi_align() + # test_range() + # test_loop() + # test_size() + # test_maxunpool() + # test_softplus() + # test_cumsum() + # test_wrong_input() + # test_aten() + # test_reverse_sequence() + # test_eyelike() + # test_qlinearconv() From 6654f7de5b1749428c7263782e96878eb5a7a355 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 24 Jun 2021 01:41:45 +0000 Subject: [PATCH 2/2] remove silly commented out lines. --- tests/python/frontend/onnx/test_forward.py | 162 ++++++++++----------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index aa18fe73551d..2f92f2d51994 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4657,85 +4657,85 @@ def repeat(N, D): if __name__ == "__main__": - # test_flatten() - # test_reshape() - # test_shape() - # test_expand() - # test_power() - # test_squeeze() - # test_unsqueeze() - # test_slice() - # test_floor() - # test_ceil() - # test_round() - # test_isinf() - # test_isnan() - # test_clip() - # test_clip_min_max_as_inputs() - # test_onehot() - # test_gemm() - # test_matmul() - # test_gather() - # test_gatherelements() - # test_gather_nd() - # test_scatter() - # test_lrn() - # test_instance_norm() - # test_upsample() - # test_forward_min() - # test_forward_max() - # test_forward_mean() - # test_forward_hardsigmoid() - # test_forward_arg_min_max() - # test_softmax() - # test_constantofshape() - # test_all_reduce_funcs() - # test_pad() - # test_split() - # test_binary_ops() - # test_unary_ops() - # test_leaky_relu() - # test_elu() - # test_selu() - # test_prelu() - # test_ThresholdedRelu() - # test_LogSoftmax() - # test_resnet() - # test_inception() - # test_densenet() - # test_sign() - # test_not() - # test_and() - # test_tile() - # test_erf() - # test_where() - # test_or() - # test_depth_to_space() - # test_space_to_depth() - # test_batch_norm() - # test_batch_norm_dynamic_subgraph() + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_gemm() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_unary_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() test_conv() - # test_convtranspose() - # test_unsqueeze_constant() - # test_pooling() - # test_lppool() - # test_lstm() - # test_gru() - # test_resize() - # test_nonzero() - # test_topk() - # test_mod() - # test_xor() - # test_max_roi_pool() - # test_roi_align() - # test_range() - # test_loop() - # test_size() - # test_maxunpool() - # test_softplus() - # test_cumsum() - # test_wrong_input() - # test_aten() - # test_reverse_sequence() - # test_eyelike() - # test_qlinearconv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() + test_range() + test_loop() + test_size() + test_maxunpool() + test_softplus() + test_cumsum() + test_wrong_input() + test_aten() + test_reverse_sequence() + test_eyelike() + test_qlinearconv()