diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3576242e0d77..c1c150fe7405 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1122,19 +1122,19 @@ def convert_dropout(node, **kwargs): """ from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) - opset_version = kwargs["opset_version"] - probability = float(attrs.get("p", 0.5)) + _ = float(attrs.get("p", 0.5)) + _ = convert_string_to_list(attrs.get("axes", "None")) + mode = attrs.get('mode', 'training') - if opset_version >= 12: - # opset >= 12 requires the ratio to be an input - nodes = [ - create_const_scalar_node(name+"_ratio0", np.float32(probability), kwargs), - make_node("Dropout", [input_nodes[0], name+"_ratio0"], [name], name=name) - ] - return nodes - else: - return [make_node("Dropout", input_nodes, [name], ratio=probability, name=name)] + if mode != 'training': + raise NotImplementedError("Dropout does not currently support mode!=\'training\'") + + nodes = [ + make_node('Identity', [input_nodes[0]], [name]) + ] + + return nodes @mx_op.register("Flatten") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f4b44e58f188..53850fd9fbb5 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -896,6 +896,15 @@ def test_onnx_export_broadcast_mul(tmp_path, dtype): op_export_test('broadcast_mul', M, [x, y], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 2, 3, 2, 1)]) +@pytest.mark.parametrize('p', [0, 0.1, 0.5, 1]) +def test_onnx_export_dropout(tmp_path, dtype, shape, p): + x = mx.random.uniform(-100, 100, shape=shape).astype(dtype) + M = def_model('Dropout', p=p) + op_export_test('Dropuout', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float32']) @pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)]) @pytest.mark.parametrize('num_filter', [2, 4, 32])