diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index b9a7ef03528c..985f07dd7e89 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -229,34 +229,45 @@ def convert_weights_and_inputs(node, **kwargs):
         return [tval_node]
 
 
-@mx_op.register("Convolution")
+@mx_op.register('Convolution')
 def convert_convolution(node, **kwargs):
     """Map MXNet's convolution operator attributes to onnx's Conv operator
     and return the created node.
     """
+    from onnx.helper import make_node
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    kernel_dims = list(parse_helper(attrs, "kernel"))
-    stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
-    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
-    num_group = int(attrs.get("num_group", 1))
-    dilations = list(parse_helper(attrs, "dilate", [1, 1]))
+    kernel = convert_string_to_list(attrs.get('kernel', '()'))
+    stride = convert_string_to_list(attrs.get('stride', '(1, 1)'))
+    dilate = convert_string_to_list(attrs.get('dilate', '(1, 1)'))
+    pad = convert_string_to_list(attrs.get('pad', '(0, 0)'))
+    num_group = int(attrs.get('num_group', 1))
+    no_bias = attrs.get('no_bias', 'False')
+    layout = attrs.get('layout', 'NCHW')
 
-    pad_dims = pad_dims + pad_dims
+    if layout != 'NCHW':
+        raise NotImplementedError('Convolution currently does not support layout!=\'NCHW\'')
 
-    conv_node = onnx.helper.make_node(
-        "Conv",
-        inputs=input_nodes,
-        outputs=[name],
-        kernel_shape=kernel_dims,
-        strides=stride_dims,
-        dilations=dilations,
-        pads=pad_dims,
-        group=num_group,
-        name=name
-    )
+    if no_bias == 'True':
+        assert len(input_nodes) == 2, 'Convolution takes 2 input if no_bias==True'
+    else:
+        assert len(input_nodes) == 3, 'Convolution takes 3 input if no_bias==False'
+
+    kwargs_ = {}
+    if kernel:
+        kwargs_['kernel_shape'] = tuple(kernel)
+    if pad:
+        kwargs_['pads'] = tuple(pad) + tuple(pad)
+    if stride:
+        kwargs_['strides'] = stride
+    if dilate:
+        kwargs_['dilations'] = dilate
 
-    return [conv_node]
+    nodes = [
+        make_node('Conv', input_nodes, [name], group=num_group, **kwargs_)
+    ]
+
+    return nodes
 
 
 @mx_op.register("Deconvolution")
@@ -679,92 +690,77 @@ def convert_linalg_gemm2(node, **kwargs):
         return [trans_a_node, trans_b_node, matmul_node]
 
 
-@mx_op.register("Pooling")
+@mx_op.register('Pooling')
 def convert_pooling(node, **kwargs):
     """Map MXNet's Pooling operator attributes to onnx's
     MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
-    based on the input node's attributes and return the created node.
     """
-    opset_version = kwargs["opset_version"]
+    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    kernel = eval(attrs["kernel"])
-    pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
-    stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
-    global_pool = get_boolean_attribute_value(attrs, "global_pool")
-    p_value = attrs.get('p_value', 'None')
-
+    kernel = convert_string_to_list(attrs.get('kernel', '()'))
+    pool_type = attrs.get('pool_type', 'max')
+    global_pool = attrs.get('global_pool', 'False')
+    _ = attrs.get('cudnn_off', 'False')
     pooling_convention = attrs.get('pooling_convention', 'valid')
-    ceil_mode = False
-    if pooling_convention == 'full':
-        if opset_version < 10:
-            pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \
-                              "This might lead to shape or accuracy issues. " \
-                              "https://github.com/onnx/onnx/issues/549"
-            logging.warning(pooling_warning)
-        ceil_mode = True
-
-    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
-    pad_dims = pad_dims + pad_dims
-    pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
-    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
-                         "lp": "GlobalLpPool"}
+    stride = convert_string_to_list(attrs.get('stride', '(1, 1)'))
+    pad = convert_string_to_list(attrs.get('pad', '()'))
+    p_value = int(attrs.get('p_value', '0'))
+    count_include_pad = attrs.get('count_include_pad', 'True')
+    layout = attrs.get('layout', 'NCHW')
+
+    if pooling_convention == 'same':
+        raise NotImplementedError('Pooling currently does not support '
+                                  'pooling_convention==\'same\'')
+    if pool_type == 'sum':
+        raise NotImplementedError('Pooling currently does not support pool_type==\'sum\'')
+    if pool_type == 'lp' and global_pool == 'False' and pooling_convention != 'valid':
+        raise NotImplementedError('Pooling currently does not support '
+                                  'pooling_convention!=\'valid\' when pool_type==\'lp\' and global_pool==False')
+    if layout != 'NCHW':
+        raise NotImplementedError('Pooling currently does not support layout!=\'NCHW\'')
+
+    kwargs_ = {}
+    if kernel:
+        kwargs_['kernel_shape'] = tuple(kernel)
+    if pad:
+        kwargs_['pads'] = tuple(pad) + tuple(pad)
+    if stride:
+        kwargs_['strides'] = stride
+
+    ceil_mode = 1 if pooling_convention == 'full' else 0
+    count_include_pad = 1 if count_include_pad == 'True' else 0
 
-    if pool_type == 'lp' and p_value == 'None':
-        raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool')
-
-    if global_pool:
-        if pool_type == 'lp':
-            node = onnx.helper.make_node(
-                global_pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                p=int(p_value),
-                name=name
-            )
-        else:
-            node = onnx.helper.make_node(
-                global_pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                name=name
-            )
+    nodes = []
+    if pool_type == 'avg' and global_pool == 'False':
+        nodes += [
+            make_node('AveragePool', [input_nodes[0]], [name], ceil_mode=ceil_mode,
+                      count_include_pad=count_include_pad, **kwargs_)
+        ]
+    elif pool_type == 'max' and global_pool == 'False':
+        nodes += [
+            make_node('MaxPool', [input_nodes[0]], [name], ceil_mode=ceil_mode, **kwargs_)
+        ]
+    elif pool_type == 'lp' and global_pool == 'False':
+        nodes += [
+            make_node('LpPool', [input_nodes[0]], [name], p=p_value, **kwargs_)
+        ]
+    elif pool_type == 'avg' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalAveragePool', [input_nodes[0]], [name])
+        ]
+    elif pool_type == 'max' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalMaxPool', [input_nodes[0]], [name])
+        ]
+    elif pool_type == 'lp' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalLpPool', [input_nodes[0]], [name], p=p_value)
+        ]
     else:
-        if pool_type == 'lp':
-            node = onnx.helper.make_node(
-                pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                p=int(p_value),
-                kernel_shape=kernel,
-                pads=pad_dims,
-                strides=stride,
-                name=name
-            )
-        else:
-            if opset_version >= 10:
-                node = onnx.helper.make_node(
-                    pool_types[pool_type],
-                    input_nodes,  # input
-                    [name],
-                    kernel_shape=kernel,
-                    pads=pad_dims,
-                    strides=stride,
-                    name=name,
-                    ceil_mode=ceil_mode
-                )
-            else:
-                node = onnx.helper.make_node(
-                    pool_types[pool_type],
-                    input_nodes,  # input
-                    [name],
-                    kernel_shape=kernel,
-                    pads=pad_dims,
-                    strides=stride,
-                    name=name
-                )
+        raise NotImplementedError('Unknown parameter values in Pooling')
 
-    return [node]
+    return nodes
 
 
 @mx_op.register("exp")
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index de3b1e93461d..1ff982100dd3 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -39,7 +39,7 @@ def hybrid_forward(self, F, *inputs):
                 return func(*inputs, **params)
     return Model
 
-def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False):
+def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False, onnx_map=None):
     def export_to_onnx(model, model_name, inputs):
         model_path = '{}/{}'.format(tmp_path, model_name)
         model.export(model_path, epoch=0)
@@ -69,9 +69,11 @@ def onnx_rt(onnx_file, inputs):
         pred_nat = pred_nat[0]
     if isinstance(pred_nat, list):
         for i in range(len(pred_nat)):
-            assert_almost_equal(pred_nat[i], pred_onx[i], equal_nan=True)
+            pred_onx_i = onnx_map(pred_onx[i]) if onnx_map else pred_onx[i]
+            assert_almost_equal(pred_nat[i], pred_onx_i, equal_nan=True)
     else:
-        assert_almost_equal(pred_nat, pred_onx[0], equal_nan=True)
+        pred_onx = onnx_map(pred_onx[0]) if onnx_map else pred_onx[0]
+        assert_almost_equal(pred_nat, pred_onx, equal_nan=True)
 
 
 def test_onnx_export_abs(tmp_path):
@@ -752,6 +754,107 @@ def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
     op_export_test('batch_dot2', M2, [x2, y2], tmp_path)
 
 
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('count_include_pad', [True, False])
+@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_avg(tmp_path, dtype, shape, count_include_pad, pooling_convention,
+                                 kernel, stride, pad):
+    # mxnet and onnxruntime have different implementations of count_include_pad on the left column
+    # and bottom row
+    if pooling_convention == 'full' and count_include_pad == True:
+        return
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', count_include_pad=count_include_pad, pool_type='avg',
+                  pooling_convention=pooling_convention, **kwargs)
+    # Note here we use np.nan_to_num to map the onnx output because onnxruntime AveragePool will
+    # output NaN in some edge cases where mxnet outputs 0
+    op_export_test('pooling_avg', M, [x], tmp_path, onnx_map=np.nan_to_num)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_max(tmp_path, dtype, shape, pooling_convention, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', pool_type='max', pooling_convention=pooling_convention, **kwargs)
+    op_export_test('pooling_max', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('p_value', [1, 2])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_lp(tmp_path, dtype, shape, p_value, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', pool_type='lp', pooling_convention='valid',
+                  p_value=p_value, **kwargs)
+    op_export_test('pooling_lp', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('pool_type', ['avg', 'max', 'lp'])
+@pytest.mark.parametrize('p_value', [1, 2])
+@pytest.mark.parametrize('kernel', [(3, 3), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (3, 4)])
+@pytest.mark.parametrize('pad', [None, (3, 4)])
+def test_onnx_export_pooling_global(tmp_path, dtype, shape, pool_type, p_value, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    # kernel, stride, and pad should have no effect on the results
+    M = def_model('Pooling', global_pool=True, pool_type=pool_type, pooling_convention='valid',
+                  p_value=p_value, **kwargs)
+    op_export_test('pooling_global', M, [x], tmp_path)
+
+
 @pytest.mark.parametrize('dtype', ['float16', 'float32'])
 def test_onnx_export_log2(tmp_path, dtype):
     x = mx.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
@@ -779,3 +882,36 @@ def test_onnx_export_broadcast_mul(tmp_path, dtype):
     x = mx.nd.array([[1,2,3],[4,5,6]], dtype=dtype)
     y = mx.nd.array([[0],[3]], dtype=dtype)
     op_export_test('broadcast_mul', M, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)])
+@pytest.mark.parametrize('num_filter', [2, 4, 32])
+@pytest.mark.parametrize('num_group', [1, 2])
+@pytest.mark.parametrize('no_bias', [True, False])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+@pytest.mark.parametrize('dilate', [None, (1, 1)])
+def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group, no_bias,
+                                 kernel, stride, pad, dilate):
+    if shape[1] % num_group:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    w_shape = (num_filter,) + (shape[1] // num_group,) + kernel
+    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
+    b_shape = (num_filter)
+    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    if dilate:
+        kwargs['dilate'] = dilate
+    M = def_model('Convolution', num_filter=num_filter, num_group=num_group, no_bias=no_bias,
+                  **kwargs)
+    inputs = [x, w] if no_bias else [x, w, b]
+    op_export_test('convolution', M, inputs, tmp_path)
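
Note (illustrative sketch, not part of the patch): both converters parse MXNet's string-encoded tuple attributes and duplicate the symmetric per-axis pad into ONNX's begin/end pads layout. The snippet below mirrors that mapping in isolation; it uses ast.literal_eval in place of the module's convert_string_to_list helper, and mx_pooling_attrs_to_onnx is a hypothetical name, not a function introduced by this change.

    from ast import literal_eval

    def mx_pooling_attrs_to_onnx(attrs):
        # Illustrative only: parse MXNet's stringified tuples, e.g. '(3, 3)' -> (3, 3)
        kernel = literal_eval(attrs.get('kernel', '()'))
        stride = literal_eval(attrs.get('stride', '(1, 1)'))
        pad = literal_eval(attrs.get('pad', '()'))
        kwargs_ = {}
        if kernel:
            kwargs_['kernel_shape'] = tuple(kernel)
        if pad:
            # MXNet stores one symmetric pad per axis; ONNX expects begin pads followed by end pads
            kwargs_['pads'] = tuple(pad) + tuple(pad)
        if stride:
            kwargs_['strides'] = tuple(stride)
        return kwargs_

    print(mx_pooling_attrs_to_onnx({'kernel': '(3, 3)', 'pad': '(1, 2)'}))
    # {'kernel_shape': (3, 3), 'pads': (1, 2, 1, 2), 'strides': (1, 1)}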