From d3a71ed3e518f5b14339f521840437fce70e0301 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Wed, 14 Apr 2021 21:12:30 +0000
Subject: [PATCH 1/5] legacy operator unit tests + fixes

---
 .../_op_translations_opset12.py               | 260 ++++++++++--------
 tests/python-pytest/onnx/test_operators.py    | 173 ++++++++++++
 2 files changed, 315 insertions(+), 118 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index d683aad7000c..ec24d6574380 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -171,6 +171,8 @@ def create_const_scalar_node(input_name, value, kwargs):
     from onnx.helper import make_tensor
     initializer = kwargs["initializer"]
     input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
+    if value.dtype == np.float16:
+        value = value.view(dtype=np.uint16)
     tensor_node = make_tensor(input_name, input_type, (), ([value]))
     initializer.append(tensor_node)
 
@@ -276,67 +278,95 @@ def convert_convolution(node, **kwargs):
     return nodes
 
 
-@mx_op.register("Deconvolution")
+@mx_op.register('Deconvolution')
 def convert_deconvolution(node, **kwargs):
     """Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator
     and return the created node.
     """
-    name, inputs, attrs = get_inputs(node, kwargs)
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    kernel_dims = list(parse_helper(attrs, "kernel"))
-    stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
-    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
-    num_group = int(attrs.get("num_group", 1))
-    dilations = list(parse_helper(attrs, "dilate", [1, 1]))
-    adj_dims = list(parse_helper(attrs, "adj", [0, 0]))
-
-    pad_dims = pad_dims + pad_dims
+    kernel_shape = convert_string_to_list(attrs.get('kernel', '()'))
+    strides = convert_string_to_list(attrs.get('stride', '()'))
+    pads = convert_string_to_list(attrs.get('pad', '()'))
+    group = int(attrs.get("num_group", 1))
+    dilations = convert_string_to_list(attrs.get('dilate', '()'))
+    output_padding = convert_string_to_list(attrs.get('adj', '()'))
+    layout = attrs.get('layout', 'NCHW')
+    target_shape = attrs.get('target_shape', '')
+    no_bias = attrs.get('no_bias', 'False')
+
+    pads = pads + pads
+
+    if target_shape != '':
+        raise NotImplementedError('Deconvolution currently does not support target_shape')
+
+    if layout not in ['NCHW', 'NCDHW', 'NCW']:
+        raise NotImplementedError('Deconvolution currently does not support layout not in '
+                                  '[\'NCHW\', \'NCDHW\', \'NCW\']')
+
+    if no_bias == 'True':
+        assert len(input_nodes) == 2, 'Deconvolution takes 2 inputs if no_bias==True'
+    else:
+        assert len(input_nodes) == 3, 'Deconvolution takes 3 inputs if no_bias==False'
+
+    kwargs_ = {}
+    if kernel_shape:
+        kwargs_['kernel_shape'] = kernel_shape
+    if pads:
+        kwargs_['pads'] = pads
+    if strides:
+        kwargs_['strides'] = strides
+    if dilations:
+        kwargs_['dilations'] = dilations
+    if output_padding:
+        kwargs_['output_padding'] = output_padding
 
     deconv_node = onnx.helper.make_node(
         "ConvTranspose",
-        inputs=inputs,
+        inputs=input_nodes,
         outputs=[name],
-        kernel_shape=kernel_dims,
-        strides=stride_dims,
-        dilations=dilations,
-        output_padding=adj_dims,
-        pads=pad_dims,
-        group=num_group,
-        name=name
+        group=group,
+        **kwargs_
    )
 
     return [deconv_node]
 
 
-@mx_op.register("Crop")
+@mx_op.register('Crop')
 def convert_crop(node, **kwargs):
-    """Map MXNet's crop operator attributes to onnx's Crop operator
-    and return the created node.
+    """Map MXNet's crop operator attributes to onnx's Slice operator
     """
+    from onnx.helper import make_node
     name, inputs, attrs = get_inputs(node, kwargs)
-    num_inputs = len(inputs)
 
-    y, x = list(parse_helper(attrs, "offset", [0, 0]))
-    h, w = list(parse_helper(attrs, "h_w", [0, 0]))
-    if num_inputs > 1:
-        h, w = kwargs["out_shape"][-2:]
-    border = [x, y, x + w, y + h]
-
-    crop_node = onnx.helper.make_node(
-        "Crop",
-        inputs=[inputs[0]],
-        outputs=[name],
-        border=border,
-        scale=[1, 1],
-        name=name
-    )
-
-    logging.warning(
-        "Using an experimental ONNX operator: Crop. " \
-        "Its definition can change.")
-
-    return [crop_node]
+    num_inputs = len(inputs)
+    y, x = convert_string_to_list(attrs.get('offset', '(0, 0)'))
+    h, w = convert_string_to_list(attrs.get('h_w', '(0, 0)'))
+    center_crop = attrs.get('center_crop', 'False')
+
+    if center_crop in ['True', '1']:
+        raise NotImplementedError('Crop does not currently support center_crop==True')
+
+    nodes = []
+    create_tensor([y, x], name+'_starts', kwargs['initializer'])
+    create_tensor([2, 3], name+'_axes', kwargs['initializer'])
+    if num_inputs == 1:
+        create_tensor([y + h, x + w], name+'_ends', kwargs['initializer'])
+    else:
+        create_tensor([0], name+'_0', kwargs['initializer'])
+        create_tensor([2], name+'_2', kwargs['initializer'])
+        create_tensor([4], name+'_4', kwargs['initializer'])
+        nodes += [
+            make_node('Shape', [inputs[1]], [name+'_shape']),
+            make_node('Slice', [name+'_shape', name+'_2', name+'_4', name+'_0'], [name+'_h_w']),
+            make_node('Add', [name+'_starts', name+'_h_w'], [name+'_ends'])
+        ]
+    nodes += [
+        make_node('Slice', [inputs[0], name+'_starts', name+'_ends', name+'_axes'], [name])
+    ]
+
+    return nodes
 
 @mx_op.register("FullyConnected")
 def convert_fully_connected(node, **kwargs):
@@ -529,12 +559,16 @@ def convert_pad(node, **kwargs):
     from onnx.helper import make_node
     opset_version = kwargs["opset_version"]
     name, input_nodes, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
 
     mxnet_pad_width = convert_string_to_list(attrs.get("pad_width"))
     onnx_pad_width = transform_padding(mxnet_pad_width)
 
     pad_mode = attrs.get("mode")
-    pad_value = np.float32(attrs.get("constant_value", 0.0))
+    pad_value = float(attrs.get("constant_value", 0.0))
+    pad_value = dtype.type(pad_value)
 
     if opset_version >= 11:
         # starting with opset 11, pads and constant_value are inputs instead of attributes
@@ -584,31 +618,58 @@ def create_helper_trans_node(node_name, input_node):
     return trans_node
 
 
+# Note that due to an ONNX limitation, the behavior when the inputs are more than 2-D
+# differs from that of MXNet
 @mx_op.register("dot")
 def convert_dot(node, **kwargs):
     """Map MXNet's dot operator attributes to onnx's
     MatMul and Transpose operators based on the values set for
     transpose_a, transpose_b attributes."""
-    name, input_nodes, attrs = get_inputs(node, kwargs)
-
+    logging.warning('Converting dot operator... Please note that due to an ONNX limitation, the '
+                    'behavior when the inputs are more than 2-D differs from that of MXNet dot.')
+
+    name, inputs, attrs = get_inputs(node, kwargs)
     trans_a = get_boolean_attribute_value(attrs, "transpose_a")
     trans_b = get_boolean_attribute_value(attrs, "transpose_b")
-
+
     nodes = []
     input_nodes = []
     if trans_a:
-        nodes.append(create_helper_trans_node(name+"_a", input_nodes[0]))
+        nodes.append(create_helper_trans_node(name+"_a", inputs[0]))
         input_nodes.append(name+"_a")
     else:
-        input_nodes.append(input_nodes[0])
+        input_nodes.append(inputs[0])
 
     if trans_b:
-        nodes.append(create_helper_trans_node(name+"_b", input_nodes[1]))
+        nodes.append(create_helper_trans_node(name+"_b", inputs[1]))
         input_nodes.append(name+"_b")
     else:
-        input_nodes.append(input_nodes[1])
+        input_nodes.append(inputs[1])
+
+    nodes.append(onnx.helper.make_node('MatMul', input_nodes, [name], name=name))
+    return nodes
 
-    nodes.appennd(onnx.helper.make_node('MatMul', input_nodes, [name], name=name))
+
+def transpose_last_two_dim(name, kwargs):
+    from onnx.helper import make_node
+    create_tensor([0], name+'_0', kwargs['initializer'])
+    create_tensor([1], name+'_1', kwargs['initializer'])
+    create_tensor([10], name+'_10', kwargs['initializer'])
+    perm = [i for i in range(10)]
+    perm[8], perm[9] = 9, 8
+    nodes = [
+        make_node('Shape', [name], [name+'_shape']),
+        make_node('Shape', [name+'_shape'], [name+'_dim']),
+        make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
+        make_node('Concat', [name+'_sub', name+'_0'], [name+'_concat'], axis=0),
+        make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
+        make_node('Reshape', [name, name+'_shape_10_dim'], [name+'_data_10_dim']),
+        make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
+        make_node('Shape', [name+'_data_t'], [name+'_new_shape_']),
+        make_node('Slice', [name+'_new_shape_', name+'_sub', name+'_10', name+'_0'],
+                  [name+'_new_shape']),
+        make_node('Reshape', [name+'_data_t', name+'_new_shape'], [name+'_transposed']),
+    ]
     return nodes
 
 
@@ -619,84 +680,47 @@ def convert_linalg_gemm2(node, **kwargs):
     transpose_a, transpose_b attributes.
     Return multiple nodes created.
     """
-    name, input_nodes, attrs = get_inputs(node, kwargs)
-
-    # Getting the attributes and assigning default values.
-    alpha = float(attrs.get("alpha", 1.0))
-    trans_a = get_boolean_attribute_value(attrs, "transpose_a")
-    trans_b = get_boolean_attribute_value(attrs, "transpose_b")
-
-    op_name = "transpose" + str(kwargs["idx"])
-
-    if alpha == 1.0 and trans_a == 0 and trans_b == 0:
-        matmul_node = onnx.helper.make_node(
-            'MatMul',
-            inputs=input_nodes,
-            outputs=[name],
-            name=name
-        )
-        return [matmul_node]
-    elif trans_a == 1 and trans_b == 0:
-        op_name = "transpose" + str(kwargs["idx"])
-        node_name = op_name+"_a"
-        trans_a_node = onnx.helper.make_node(
-            'Transpose',
-            inputs=[input_nodes[0]],
-            outputs=[op_name+"_a"],
-            name=node_name
-        )
-
-        matmul_node = onnx.helper.make_node(
-            'MatMul',
-            inputs=[node_name, input_nodes[1]],
-            outputs=[name],
-            name=name
-        )
-        return [trans_a_node, matmul_node]
-
-    elif trans_a == 0 and trans_b == 1:
-        node_name = op_name + "_b"
-        trans_b_node = onnx.helper.make_node(
-            'Transpose',
-            inputs=[input_nodes[1]],
-            outputs=[op_name+"_b"],
-            name=node_name
-        )
-
-        matmul_node = onnx.helper.make_node(
-            'MatMul',
-            inputs=[input_nodes[0], node_name],
-            outputs=[name],
-            name=name
-        )
-
-        return [trans_b_node, matmul_node]
-    else:
-        node_name_a = op_name+"_a"
-        trans_a_node = onnx.helper.make_node(
-            'Transpose',
-            inputs=[input_nodes[0]],
-            outputs=[op_name+"_a"],
-            name=node_name_a
-        )
-
-        node_name_b = op_name + "_b"
-        trans_b_node = onnx.helper.make_node(
-            'Transpose',
-            inputs=[input_nodes[1]],
-            outputs=[op_name+"_b"],
-            name=node_name_b
-        )
-
-        matmul_node = onnx.helper.make_node(
-            'MatMul',
-            inputs=input_nodes,
-            outputs=[name],
-            name=name
-        )
-
-        return [trans_a_node, trans_b_node, matmul_node]
+    from onnx.helper import make_node
+    name, inputs, attrs = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]
+
+    # Getting the attributes and assigning default values.
+    alpha = float(attrs.get('alpha', 1.0))
+    axis = attrs.get('axis', 'None')
+    trans_a = get_boolean_attribute_value(attrs, 'transpose_a')
+    trans_b = get_boolean_attribute_value(attrs, 'transpose_b')
+
+    if axis != 'None':
+        raise NotImplementedError('_linalg_gemm2 does not currently support axis!=None')
+
+    nodes = []
+    input_nodes = []
+    if trans_a:
+        nodes += transpose_last_two_dim(inputs[0], kwargs)
+        input_nodes.append(inputs[0]+'_transposed')
+    else:
+        input_nodes.append(inputs[0])
+
+    if trans_b:
+        nodes += transpose_last_two_dim(inputs[1], kwargs)
+        input_nodes.append(inputs[1]+'_transposed')
+    else:
+        input_nodes.append(inputs[1])
+
+    if alpha == 1:
+        nodes += [
+            make_node('MatMul', input_nodes, [name])
+        ]
+        return nodes
+
+    create_const_scalar_node(name+"_alpha", dtype.type(alpha), kwargs)
+    nodes += [
+        make_node('MatMul', input_nodes, [name+'_matmul']),
+        make_node('Mul', [name+'_matmul', name+'_alpha'], [name])
+    ]
+    return nodes
 
 @mx_op.register('Pooling')
 def convert_pooling(node, **kwargs):
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index b032fa7fc1bd..af786f8e2978 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -66,6 +66,8 @@ def onnx_rt(onnx_file, inputs):
     pred_onx = onnx_rt(onnx_file, inputs)
     if dummy_input:
         pred_mx = pred_mx[0]
+    print(pred_mx)
+    print(pred_onx)
     if isinstance(pred_mx, list):
         for i in range(len(pred_mx)):
             pred_onx_i = onnx_map(pred_onx[i]) if onnx_map else pred_onx[i]
@@ -1275,3 +1277,174 @@ def test_onnx_export_contrib_div_sqrt_dim(tmp_path, dtype, shape):
     A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
     M = def_model('contrib.div_sqrt_dim')
     op_export_test('contrib_div_sqrt_dim', M, [A], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
+@pytest.mark.parametrize('operator', ['sin', 'cos', 'tan', 'tanh', 'arcsin', 'arccos', 'arctan',
+                                      'sigmoid', 'relu', 'exp', 'identity', 'BlockGrad', 'MakeLoss'])
+def test_onnx_export_ufunc(tmp_path, dtype, shape, operator):
+    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
+    M = def_model(operator)
+    op_export_test('ufunc', M, [A], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)])
+@pytest.mark.parametrize('num_filter', [4, 16, 256])
+@pytest.mark.parametrize('num_group', [1, 2])
+@pytest.mark.parametrize('no_bias', [False, True])
+@pytest.mark.parametrize('kernel', [(2, 2), (3, 4)])
+@pytest.mark.parametrize('stride', [(1, 1), (2, 2)])
+@pytest.mark.parametrize('pad', [None, (0, 0), (1, 1)])
+@pytest.mark.parametrize('dilate', [None, (1, 1)])
+@pytest.mark.parametrize('adj', [(0, 0), (1, 1)])
+def test_onnx_export_deconvolution(tmp_path, dtype, shape, num_filter, num_group, no_bias,
+                                   kernel, stride, pad, dilate, adj):
+    for i in range(len(stride)):
+        if stride[i] <= adj[i]:
+            return
+    if shape[1] % num_group:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    w_shape = (shape[1],) + (num_filter // num_group,) + kernel
+    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
+    b_shape = (num_filter,)
+    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    if dilate:
+        kwargs['dilate'] = dilate
+    if adj:
+        kwargs['adj'] = adj
+    M = def_model('Deconvolution', num_filter=num_filter, num_group=num_group, no_bias=no_bias,
+                  layout='NCHW', **kwargs)
+    inputs = [x, w] if no_bias else [x, w, b]
+    op_export_test('deconvolution', M, inputs, tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float16', 'float64'])
+@pytest.mark.parametrize('mode', ['edge', 'constant', 'reflect'])
+@pytest.mark.parametrize('params', [((3, 4, 5, 6), (0, 0, 0, 0, 2, 3, 4, 5)),
+                                    ((7, 6, 5, 4, 3), (0, 0, 0, 0, 4, 4, 3, 3, 2, 1))])
+def test_onnx_export_pad(tmp_path, dtype, mode, params):
+    kwargs = {}
+    kwargs['constant_value'] = 9999.55
+    kwargs['pad_width'] = params[1]
+    x = mx.random.uniform(0, 1, shape=params[0], dtype=dtype)
+    M = def_model('pad', mode=mode, **kwargs)
+    op_export_test('pad', M, [x], tmp_path)
+
+
+# Note that due to an ONNX limitation, the behavior when the inputs are more than 2-D
+# differs from that of MXNet
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('params', [((4,), (4,), False, False),
+                                    ((4, 5), (5, 6), False, False),
+                                    ((5, 4), (5, 6), True, False),
+                                    ((5, 4), (6, 5), True, True),
+                                    ((4, 5), (6, 5), False, True),
+                                    ((4, 5), (5), False, False),
+                                    ((4), (4, 5), False, False)])
+def test_onnx_export_dot(tmp_path, dtype, params):
+    A = mx.random.uniform(0, 1, params[0], dtype=dtype)
+    B = mx.random.uniform(0, 1, params[1], dtype=dtype)
+    M = def_model('dot', transpose_a=params[2], transpose_b=params[3])
+    op_export_test('dot', M, [A, B], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (7, 8)])
+def test_onnx_export_flatten(tmp_path, dtype, shape):
+    x = mx.random.uniform(0, 1, shape, dtype='float32').astype(dtype)
+    M = def_model('flatten')
+    op_export_test('flatten', M, [x], tmp_path)
+
+
+# Note that due to an ONNX limitation, the behavior when the inputs are more than 2-D
+# differs from that of MXNet
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('alpha', [1, 1.5])
+@pytest.mark.parametrize('params', [((4, 5), (5, 4), False, False),
+                                    ((4, 5, 6), (4, 6, 5), False, False),
+                                    ((4, 5, 6, 7), (4, 5, 6, 7), True, False),
+                                    ((4, 5, 6, 7), (4, 5, 6, 7), False, True),
+                                    ((4, 5, 9, 7), (4, 5, 6, 9), True, True)])
+def test_onnx_export_linalg_gemm2(tmp_path, dtype, alpha, params):
+    A = mx.random.uniform(0, 1, params[0], dtype=dtype)
+    B = mx.random.uniform(0, 1, params[1], dtype=dtype)
+    M = def_model('linalg.gemm2', alpha=alpha, transpose_a=params[2], transpose_b=params[3])
+    op_export_test('_linalg_gemm2', M, [A, B], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
+def test_onnx_export_LogisticRegressionOutput(tmp_path, dtype, shape):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    y = mx.nd.zeros(shape, dtype=dtype)
+    M = def_model('LogisticRegressionOutput')
+    op_export_test('LogisticRegressionOutput', M, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (3, 4, 5, 6, 7)])
+def test_onnx_export_SoftmaxOutput(tmp_path, dtype, shape):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    y = mx.nd.zeros(shape[:-1], dtype=dtype)
+    M = def_model('SoftmaxOutput')
+    op_export_test('SoftmaxOutput', M, [x, y], tmp_path)
+
+
+# Due to ONNX limitation, L2Normalization only supports channel mode for now
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (3, 4, 5, 6, 7)])
+def test_onnx_export_L2Normalization(tmp_path, dtype, shape):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    M = def_model('L2Normalization', mode='channel')
+    op_export_test('L2Normalization', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (3, 4, 5, 6, 7)])
+@pytest.mark.parametrize('eps', [0.001, 0.00001])
+def test_onnx_export_InstanceNorm(tmp_path, dtype, shape, eps):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    gamma = mx.random.uniform(0, 1, shape[1:2], dtype=dtype)
+    beta = mx.random.uniform(0, 1, shape[1:2], dtype=dtype)
+    M = def_model('InstanceNorm', eps=eps)
+    op_export_test('InstanceNorm', M, [x, gamma, beta], tmp_path)
+
+
+# ONNXRuntime only supports 4-D inputs
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(4, 5, 6, 7)])
+@pytest.mark.parametrize('alpha', [0.001, 0.00001])
+@pytest.mark.parametrize('beta', [0.75, 0.8])
+@pytest.mark.parametrize('knorm', [1, 2])
+@pytest.mark.parametrize('nsize', [3, 5])
+def test_onnx_export_LRN(tmp_path, dtype, shape, alpha, beta, knorm, nsize):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    M = def_model('LRN', alpha=alpha, beta=beta, knorm=knorm, nsize=nsize)
+    op_export_test('LRN', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 224, 224), (5, 6, 64, 64)])
+@pytest.mark.parametrize('h_w', [(10, 10), (7, 11)])
+@pytest.mark.parametrize('offset', [(7, 13), (10, 10)])
+@pytest.mark.parametrize('shape2', [None, (10, 10, 16, 16)])
+def test_onnx_export_Crop(tmp_path, dtype, shape, h_w, offset, shape2):
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    M = def_model('Crop', h_w=h_w, offset=offset, center_crop=True)
+    if shape2 is not None:
+        y = mx.random.uniform(0, 1, shape2, dtype=dtype)
+        op_export_test('Crop', M, [x, y], tmp_path)
+    else:
+        op_export_test('Crop', M, [x], tmp_path)
+
+

From 06a9cf148c10577d60f8c54c456add0b37234ff6 Mon Sep 17 00:00:00 2001
From: Zhaoqi Zhu
Date: Tue, 20 Apr 2021 21:43:00 -0700
Subject: [PATCH 2/5] Update _op_translations_opset12.py

---
 .../_op_translations/_op_translations_opset12.py     | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index 6464e89f9092..c8020ebd594b 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -345,8 +345,8 @@ def convert_crop(node, **kwargs):
     name, inputs, attrs = get_inputs(node, kwargs)
 
     num_inputs = len(inputs)
-    y, x = convert_string_to_list(attrs.get('offset', '(0, 0)'))
-    h, w = convert_string_to_list(attrs.get('h_w', '(0, 0)'))
+    y, x = convert_string_to_list(attrs.get('offset', '(0, 0)'))  # pylint: disable=unbalanced-tuple-unpacking
+    h, w = convert_string_to_list(attrs.get('h_w', '(0, 0)'))  # pylint: disable=unbalanced-tuple-unpacking
     center_crop = attrs.get('center_crop', 'False')
 
     if center_crop in ['True', '1']:
@@ -632,11 +632,11 @@ def convert_dot(node, **kwargs):
     transpose_a, transpose_b attributes."""
     logging.warning('Converting dot operator... Please note that due to an ONNX limitation, the '
                     'behavior when the inputs are more than 2-D differs from that of MXNet dot.')
-
+
     name, inputs, attrs = get_inputs(node, kwargs)
     trans_a = get_boolean_attribute_value(attrs, "transpose_a")
     trans_b = get_boolean_attribute_value(attrs, "transpose_b")
-
+
     nodes = []
     input_nodes = []
     if trans_a:
@@ -656,6 +656,8 @@ def convert_dot(node, **kwargs):
 
 
 def transpose_last_two_dim(name, kwargs):
+    """Helper function to transpose the last two dims of the input tensor
+    """
     from onnx.helper import make_node
     create_tensor([0], name+'_0', kwargs['initializer'])
     create_tensor([1], name+'_1', kwargs['initializer'])
@@ -719,7 +721,7 @@ def convert_linalg_gemm2(node, **kwargs):
             make_node('MatMul', input_nodes, [name])
         ]
         return nodes
-
+
     create_const_scalar_node(name+"_alpha", dtype.type(alpha), kwargs)
     nodes += [
         make_node('MatMul', input_nodes, [name+'_matmul']),

From 2f4e1d2b77af674c7f788351293ea7ab49f9dd41 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Wed, 21 Apr 2021 18:45:50 +0000
Subject: [PATCH 3/5] fixes for onnx 1.8

---
 .../_op_translations_opset13.py               | 19 +++++++++++++++++++
 tests/python-pytest/onnx/test_operators.py    | 10 +++++-----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
index 7925dbf85903..c77d0ba5b68e 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
@@ -1544,3 +1544,22 @@ def convert_squeeze(node, **kwargs):
         name=name,
     )
     return [node]
+
+
+@mx_op.register("SoftmaxOutput", OPSET_VERSION)
+def convert_softmax_output(node, **kwargs):
+    """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_shape']),
+        make_node('Flatten', [input_nodes[0]], [name+'_flat'], axis=1),
+        make_node('Softmax', [name+'_flat'], [name+'_sm'], axis=1),
+        make_node('Reshape', [name+'_sm', name+'_shape'], [name])
+    ]
+
+    return nodes
+
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 0fc5f01c01d3..3c26f3e2fd28 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1355,13 +1355,13 @@ def test_onnx_export_pad(tmp_path, dtype, mode, params):
 # Note that due to an ONNX limitation, the behavior when the inputs are more than 2-D
 # differs from that of MXNet
 @pytest.mark.parametrize('dtype', ['float32', 'float64'])
-@pytest.mark.parametrize('params', [((4,), (4,), False, False),
-                                    ((4, 5), (5, 6), False, False),
+@pytest.mark.parametrize('params', [((4, 5), (5, 6), False, False),
                                     ((5, 4), (5, 6), True, False),
                                     ((5, 4), (6, 5), True, True),
                                     ((4, 5), (6, 5), False, True),
                                     ((4, 5), (5), False, False),
-                                    ((4), (4, 5), False, False)])
+                                    ((4,), (4, 5), False, False),
+                                    ((4, 5), (5,), False, False)])
 def test_onnx_export_dot(tmp_path, dtype, params):
     A = mx.random.uniform(0, 1, params[0], dtype=dtype)
     B = mx.random.uniform(0, 1, params[1], dtype=dtype)
@@ -1403,7 +1403,7 @@ def test_onnx_export_LogisticRegressionOutput(tmp_path, dtype, shape):
 
 
 @pytest.mark.parametrize('dtype', ['float32', 'float64'])
-@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (3, 4, 5, 6, 7)])
+@pytest.mark.parametrize('shape', [(4, 5, 6), (6, 7), (3, 4, 5, 6, 7)])
 def test_onnx_export_SoftmaxOutput(tmp_path, dtype, shape):
     x = mx.random.uniform(0, 1, shape, dtype=dtype)
     y = mx.nd.zeros(shape[:-1], dtype=dtype)
     M = def_model('SoftmaxOutput')
     op_export_test('SoftmaxOutput', M, [x, y], tmp_path)
@@ -1451,7 +1451,7 @@ def test_onnx_export_LRN(tmp_path, dtype, shape, alpha, beta, knorm, nsize):
 @pytest.mark.parametrize('shape2', [None, (10, 10, 16, 16)])
 def test_onnx_export_Crop(tmp_path, dtype, shape, h_w, offset, shape2):
     x = mx.random.uniform(0, 1, shape, dtype=dtype)
-    M = def_model('Crop', h_w=h_w, offset=offset, center_crop=True)
+    M = def_model('Crop', h_w=h_w, offset=offset, center_crop=False)
     if shape2 is not None:
         y = mx.random.uniform(0, 1, shape2, dtype=dtype)
         op_export_test('Crop', M, [x, y], tmp_path)

From 21d77151c38f129850c480ac057deaa0af069491 Mon Sep 17 00:00:00 2001
From: Zhaoqi Zhu
Date: Wed, 21 Apr 2021 18:19:34 -0700
Subject: [PATCH 4/5] Update _op_translations_opset13.py

---
 .../onnx/mx2onnx/_op_translations/_op_translations_opset13.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
index c77d0ba5b68e..3d32c9a732ff 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
@@ -1552,7 +1552,7 @@ def convert_softmax_output(node, **kwargs):
     and return the created node.
     """
     from onnx.helper import make_node
-    name, input_nodes, attrs = get_inputs(node, kwargs)
+    name, input_nodes, _ = get_inputs(node, kwargs)
 
     nodes = [
         make_node('Shape', [input_nodes[0]], [name+'_shape']),
@@ -1562,4 +1562,3 @@ def convert_softmax_output(node, **kwargs):
     ]
 
     return nodes
-
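A note on the SoftmaxOutput conversion above: the emitted Shape -> Flatten -> Softmax -> Reshape sequence flattens every non-batch axis before taking the softmax, then restores the original shape. A plain NumPy mirror of that graph can serve as a quick sanity check; the sketch below is illustrative only (the helper name is made up, and nothing in it is part of the patch series).

    import numpy as np

    def softmax_output_onnx_ref(x):
        # Mirrors the ONNX graph emitted by convert_softmax_output:
        # Shape -> Flatten(axis=1) -> Softmax(axis=1) -> Reshape.
        shape = x.shape                                     # Shape
        flat = x.reshape(shape[0], -1)                      # Flatten, axis=1
        e = np.exp(flat - flat.max(axis=1, keepdims=True))
        sm = e / e.sum(axis=1, keepdims=True)               # Softmax, axis=1
        return sm.reshape(shape)                            # Reshape back

    x = np.random.rand(4, 5, 6).astype(np.float32)
    y = softmax_output_onnx_ref(x)
    # Each sample's probabilities sum to 1 across all non-batch elements.
    assert np.allclose(y.reshape(4, -1).sum(axis=1), 1.0)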
""" from onnx.helper import make_node - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, _ = get_inputs(node, kwargs) nodes = [ make_node('Shape', [input_nodes[0]], [name+'_shape']), @@ -1562,4 +1562,3 @@ def convert_softmax_output(node, **kwargs): ] return nodes - From e8ec7096c8b08a0c9fc9d30f7102b463f516c07e Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 23 Apr 2021 13:39:37 -0700 Subject: [PATCH 5/5] Update test_operators.py --- tests/python-pytest/onnx/test_operators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 91eb308538bb..2cdf9f95bf6b 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1393,7 +1393,7 @@ def test_onnx_export_linalg_gemm2(tmp_path, dtype, alpha, params): op_export_test('_linalg_gemm2', M, [A, B], tmp_path) -@pytest.mark.parametrize('dtype', ['float32', 'float64']) +@pytest.mark.parametrize('dtype', ['float32']) @pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)]) def test_onnx_export_LogisticRegressionOutput(tmp_path, dtype, shape): x = mx.random.uniform(0, 1, shape, dtype=dtype) @@ -1412,7 +1412,7 @@ def test_onnx_export_SoftmaxOutput(tmp_path, dtype, shape): # Due to ONNX limitation, L2Normalization only supports channel mode for now -@pytest.mark.parametrize('dtype', ['float32', 'float64']) +@pytest.mark.parametrize('dtype', ['float32']) @pytest.mark.parametrize('shape', [(3, 4, 5), (3, 4, 5, 6, 7)]) def test_onnx_export_L2Normalization(tmp_path, dtype, shape): x = mx.random.uniform(0, 1, shape, dtype=dtype)