From 8d87a2db6228a08c31d2515a44c08dc7980b120f Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 23 Apr 2021 21:20:08 +0000 Subject: [PATCH 1/3] fix log_softmax for opset 12 --- .../_op_translations_opset12.py | 26 ++++++++++------ .../_op_translations_opset13.py | 31 ++++++++++++++++++- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index ffc89d440e1e..677503e32e08 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -2371,22 +2371,28 @@ def convert_logsoftmax(node, **kwargs): """Map MXNet's log_softmax operator attributes to onnx's LogSoftMax operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) # Converting to int axis = int(attrs.get("axis", -1)) - temp = attrs.get("temperature", 'None') + temp = attrs.get('temperature', 'None') + use_length = attrs.get('use_length', 'False') + if temp != 'None': - raise AttributeError("LogSoftMax: ONNX supports only temperature=None") + raise AttributeError('LogSoftMax currently does not support temperature!=None') - node = onnx.helper.make_node( - 'LogSoftmax', - input_nodes, - [name], - axis=axis, - name=name - ) - return [node] + if use_length in ['1', 'True']: + raise AttributeError('LogSoftMax currently does not support use_length==True') + + nodes = [ + make_node('Exp', [input_nodes[0]], [name+'_exp']), + make_node('ReduceSum', [name+'_exp'], [name+'_rsum'], axes=[axis], keepdims=1), + make_node('Div', [name+'_exp', name+'_rsum'], [name+'_div']), + make_node('Log', [name+'_div'], [name]) + ] + + return nodes @mx_op.register("norm") def convert_norm(node, **kwargs): diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index ed4ddbc68bce..7e6ce380b068 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1597,4 +1597,33 @@ def convert_norm(node, **kwargs): make_node('Reshape', [name+'_norm', name+'_1'], [name]) ] return nodes - \ No newline at end of file + + +@mx_op.register("log_softmax", OPSET_VERSION) +def convert_logsoftmax(node, **kwargs): + """Map MXNet's log_softmax operator attributes to onnx's LogSoftMax operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + # Converting to int + axis = int(attrs.get("axis", -1)) + temp = attrs.get('temperature', 'None') + use_length = attrs.get('use_length', 'False') + + if temp != 'None': + raise AttributeError('LogSoftMax currently does not support temperature!=None') + + if use_length in ['1', 'True']: + raise AttributeError('LogSoftMax currently does not support use_length==True') + + node = onnx.helper.make_node( + 'LogSoftmax', + input_nodes, + [name], + axis=axis, + name=name + ) + + return [node] + From 85f0fd3a7afc93a065c347de824f2d8a540eb11c Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 23 Apr 2021 15:39:11 -0700 Subject: [PATCH 2/3] Update _op_translations_opset12.py --- .../onnx/mx2onnx/_op_translations/_op_translations_opset12.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 677503e32e08..3e9be5d9b20f 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -2378,13 +2378,13 @@ def convert_logsoftmax(node, **kwargs): axis = int(attrs.get("axis", -1)) temp = attrs.get('temperature', 'None') use_length = attrs.get('use_length', 'False') - + if temp != 'None': raise AttributeError('LogSoftMax currently does not support temperature!=None') if use_length in ['1', 'True']: raise AttributeError('LogSoftMax currently does not support use_length==True') - + nodes = [ make_node('Exp', [input_nodes[0]], [name+'_exp']), make_node('ReduceSum', [name+'_exp'], [name+'_rsum'], axes=[axis], keepdims=1), From d15ed04655daff785bea99bf3602e8fec439ad72 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 23 Apr 2021 16:43:49 -0700 Subject: [PATCH 3/3] Update _op_translations_opset13.py --- .../onnx/mx2onnx/_op_translations/_op_translations_opset13.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 7e6ce380b068..02e7d45464c0 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1626,4 +1626,3 @@ def convert_logsoftmax(node, **kwargs): ) return [node] -