diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..11e75d9a6000 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -60,21 +60,16 @@ import logging import numpy as np from .export_onnx import MXNetGraph as mx_op - -def import_onnx_modules(): - """ To make sure ONNX is runtime dependency, it is imported used only when needed""" - try: - import onnx - except ImportError: - raise ImportError("Onnx and protobuf need to be installed. " - + "Instructions to install - https://github.com/onnx/onnx") - return onnx +try: + import onnx +except ImportError: + onnx = None def parse_helper(attrs, attrs_name, alt_value=None): """Helper function to parse operator attributes in required format.""" tuple_re = re.compile('\([0-9L|,| ]+\)') - if attrs is None: + if not attrs: return alt_value attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name)) if attrs_str is None: @@ -135,12 +130,39 @@ def get_boolean_attribute_value(attrs, attr_name): """ return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 +def get_inputs(node, kwargs): + """Helper function to get inputs""" + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + index_lookup = kwargs["index_lookup"] + inputs = node["inputs"] + attrs = node.get("attrs", {}) + + input_nodes = [] + for ip in inputs: + input_node_id = index_lookup[ip[0]] + input_nodes.append(proc_nodes[input_node_id].name) + + return name, input_nodes, attrs + +def create_basic_op_node(op_name, node, kwargs): + """Helper function to create a basic operator + node that doesn't contain op specific attrs""" + name, input_nodes, _ = get_inputs(node, kwargs) + + node = onnx.helper.make_node( + op_name, + input_nodes, + [name], + name=name + ) + return [node] + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. """ - onnx = import_onnx_modules() - name = node["name"] + name, _, _ = get_inputs(node, kwargs) if kwargs["is_input"] is False: weights = kwargs["weights"] @@ -172,20 +194,7 @@ def convert_convolution(node, **kwargs): """Map MXNet's convolution operator attributes to onnx's Conv operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - - num_inputs = len(inputs) - - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name - weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name - - if num_inputs > 2: - bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name - - attrs = node.get("attrs") + name, input_nodes, attrs = get_inputs(node, kwargs) kernel_dims = list(parse_helper(attrs, "kernel")) stride_dims = list(parse_helper(attrs, "stride", [1, 1])) @@ -195,10 +204,6 @@ def convert_convolution(node, **kwargs): pad_dims = pad_dims + pad_dims - input_nodes = [input_node, weights_node] - if num_inputs > 2: - input_nodes.append(bias_node) - conv_node = onnx.helper.make_node( "Conv", inputs=input_nodes, @@ -219,32 +224,15 @@ def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) + initializer = kwargs["initializer"] no_bias = get_boolean_attribute_value(attrs, "no_bias") - input_node_id = kwargs["index_lookup"][inputs[0][0]] - weight_node_id = kwargs["index_lookup"][inputs[1][0]] - - proc_nodes = kwargs["proc_nodes"] - - input_node = proc_nodes[input_node_id] - input_name = input_node.name - - weights_node = proc_nodes[weight_node_id] - weights_name = weights_node.name - fcnode = [] - if no_bias == 0: - bias_node_id = kwargs["index_lookup"][inputs[2][0]] - bias_node = proc_nodes[bias_node_id] - bias_name = bias_node.name - else: + if no_bias: data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] bias_name = "bias" + str(kwargs["idx"]) tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, (1,)) @@ -257,11 +245,12 @@ def convert_fully_connected(node, **kwargs): raw=False, ) ) + input_nodes.append(bias_name) fcnode.append(tensor_node) node = onnx.helper.make_node( "Gemm", - [input_name, weights_name, bias_name], # input (A, B, C) - C can be in place + input_nodes, # input (A, B, C) - C can be in place [name], # output alpha=1.0, beta=1.0, @@ -280,37 +269,14 @@ def convert_batchnorm(node, **kwargs): """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node["attrs"] - momentum = float(node.get("attrs", {}).get("momentum", 0.9)) + momentum = float(attrs.get("momentum", 0.9)) eps = float(attrs.get("eps", 0.001)) - data_idx = kwargs["index_lookup"][inputs[0][0]] - gamma_idx = kwargs["index_lookup"][inputs[1][0]] - beta_idx = kwargs["index_lookup"][inputs[2][0]] - moving_mean_idx = kwargs["index_lookup"][inputs[3][0]] - moving_var_idx = kwargs["index_lookup"][inputs[4][0]] - - data_node = proc_nodes[data_idx].name - gamma_node = proc_nodes[gamma_idx].name - beta_node = proc_nodes[beta_idx].name - - mov_mean_node = proc_nodes[moving_mean_idx] - mov_mean_node = mov_mean_node.name - mov_var_node = proc_nodes[moving_var_idx].name - bn_node = onnx.helper.make_node( "BatchNormalization", - [data_node, - gamma_node, # scale - beta_node, # bias - mov_mean_node, - mov_var_node - ], + input_nodes, [name], name=name, epsilon=eps, @@ -327,140 +293,49 @@ def convert_tanh(node, **kwargs): """Map MXNet's tanh operator attributes to onnx's Tanh operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Tanh', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Tanh', node, kwargs) @mx_op.register("cos") def convert_cos(node, **kwargs): """Map MXNet's cos operator attributes to onnx's Cos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Cos', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Cos', node, kwargs) @mx_op.register("sin") def convert_sin(node, **kwargs): """Map MXNet's sin operator attributes to onnx's Sin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Sin', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Sin', node, kwargs) @mx_op.register("tan") def convert_tan(node, **kwargs): """Map MXNet's tan operator attributes to onnx's tan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Tan', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Tan', node, kwargs) @mx_op.register("arccos") def convert_acos(node, **kwargs): """Map MXNet's acos operator attributes to onnx's acos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Acos', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Acos', node, kwargs) @mx_op.register("arcsin") def convert_asin(node, **kwargs): """Map MXNet's asin operator attributes to onnx's asin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Asin', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Asin', node, kwargs) @mx_op.register("arctan") def convert_atan(node, **kwargs): """Map MXNet's atan operator attributes to onnx's atan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Atan', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Atan', node, kwargs) #Basic neural network functions @mx_op.register("sigmoid") @@ -468,58 +343,24 @@ def convert_sigmoid(node, **kwargs): """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Sigmoid', - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Sigmoid', node, kwargs) @mx_op.register("relu") def convert_relu(node, **kwargs): """Map MXNet's relu operator attributes to onnx's Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name - - node = onnx.helper.make_node( - 'Relu', - [input_node], - [name], - name=name - ) - - return [node] + return create_basic_op_node('Relu', node, kwargs) @mx_op.register("Activation") def convert_activation(node, **kwargs): """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] + name, input_nodes, attrs = get_inputs(node, kwargs) - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] act_type = attrs["act_type"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_idx].output[0] - # Creating a dictionary here, but if this titlecase pattern # mxnet_name.title() act_types = { @@ -534,7 +375,7 @@ def convert_activation(node, **kwargs): if act_name: node = onnx.helper.make_node( act_name, - [input_node], + input_nodes, [name], name=name ) @@ -551,13 +392,7 @@ def convert_pad(node, **kwargs): """Map MXNet's pad operator attributes to onnx's Pad operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - attrs = node["attrs"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_idx].name + name, input_nodes, attrs = get_inputs(node, kwargs) mxnet_pad_width = convert_string_to_list(attrs.get("pad_width")) onnx_pad_width = transform_padding(mxnet_pad_width) @@ -569,7 +404,7 @@ def convert_pad(node, **kwargs): if "constant_value" in attrs else 0.0 node = onnx.helper.make_node( 'Pad', - inputs=[input_node], + inputs=input_nodes, outputs=[name], mode='constant', value=pad_value, @@ -579,7 +414,7 @@ def convert_pad(node, **kwargs): else: node = onnx.helper.make_node( 'Pad', - inputs=[input_node], + inputs=input_nodes, outputs=[name], mode=pad_mode, pads=onnx_pad_width, @@ -591,8 +426,6 @@ def convert_pad(node, **kwargs): def create_helper_trans_node(op_name, input_node, node_name): """create extra transpose node for dot operator""" - onnx = import_onnx_modules() - node_name = op_name + "_" + node_name trans_node = onnx.helper.make_node( 'Transpose', @@ -608,17 +441,8 @@ def convert_dot(node, **kwargs): """Map MXNet's dot operator attributes to onnx's MatMul and Transpose operators based on the values set for transpose_a, transpose_b attributes.""" - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] - - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] - input_node_b = proc_nodes[input_b_idx].name + name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node.get('attrs', {}) trans_a_node = None trans_b_node = None @@ -626,14 +450,12 @@ def convert_dot(node, **kwargs): trans_b = get_boolean_attribute_value(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) - create_helper_trans_node(op_name, input_node_a, 'a') - create_helper_trans_node(op_name, input_node_b, 'b') if trans_a: - trans_a_node = create_helper_trans_node(op_name, input_node_a, 'a') + trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a') input_node_a = op_name+"_a" if trans_b: - trans_b_node = create_helper_trans_node(op_name, input_node_b, 'b') + trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b') input_node_b = op_name+"_b" matmul_node = onnx.helper.make_node( @@ -660,33 +482,19 @@ def convert_linalg_gemm2(node, **kwargs): transpose_a, transpose_b attributes. Return multiple nodes created. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] - - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] - input_node_b = proc_nodes[input_b_idx].name + name, input_nodes, attrs = get_inputs(node, kwargs) # Getting the attributes and assigning default values. - if "attrs" in node: - attrs = node["attrs"] - alpha = float(attrs["alpha"]) - trans_a = int(attrs["transpose_a"]) - trans_b = int(attrs["transpose_b"]) - else: - alpha = 1.0 - trans_a = 0 - trans_b = 0 + alpha = float(attrs.get("alpha", 1.0)) + trans_a = get_boolean_attribute_value(attrs, "transpose_a") + trans_b = get_boolean_attribute_value(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) if alpha == 1.0 and trans_a == 0 and trans_b == 0: matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[input_node_a, input_node_b], + inputs=input_nodes, outputs=[name], name=name ) @@ -696,14 +504,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name = op_name+"_a" trans_a_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_a], + inputs=[input_nodes[0]], outputs=[op_name+"_a"], name=node_name ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[node_name, input_node_b], + inputs=[node_name, input_nodes[1]], outputs=[name], name=name ) @@ -713,14 +521,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name = op_name + "_b" trans_b_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_b], + inputs=[input_nodes[1]], outputs=[op_name+"_b"], name=node_name ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[input_node_a, node_name], + inputs=[input_nodes[0], node_name], outputs=[name], name=name ) @@ -730,7 +538,7 @@ def convert_linalg_gemm2(node, **kwargs): node_name_a = op_name+"_a" trans_a_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_a], + inputs=[input_nodes[0]], outputs=[op_name+"_a"], name=node_name_a ) @@ -738,14 +546,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name_b = op_name + "_b" trans_b_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_b], + inputs=[input_nodes[1]], outputs=[op_name+"_b"], name=node_name_b ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[node_name_a, node_name_b], + inputs=input_nodes, outputs=[name], name=name ) @@ -759,19 +567,13 @@ def convert_pooling(node, **kwargs): MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) + kernel = eval(attrs["kernel"]) pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None global_pool = get_boolean_attribute_value(attrs, "global_pool") - node_inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx] - name = node["name"] - pooling_convention = attrs.get('pooling_convention', 'valid') if pooling_convention == 'full': @@ -789,14 +591,14 @@ def convert_pooling(node, **kwargs): if global_pool: node = onnx.helper.make_node( global_pool_types[pool_type], - [input_node.name], # input + input_nodes, # input [name], name=name ) else: node = onnx.helper.make_node( pool_types[pool_type], - [input_node.name], # input + input_nodes, # input [name], kernel_shape=kernel, pads=pad_dims, @@ -812,43 +614,14 @@ def convert_exp(node, **kwargs): """Map MXNet's exp operator attributes to onnx's Exp operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Exp", - [input_node], - [name], - name=name, - ) - return [node] - + return create_basic_op_node('Exp', node, kwargs) @mx_op.register("_copy") def convert_identity(node, **kwargs): """Map MXNet's _copy operator attributes to onnx's Identity operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Identity", - [input_node], - [name], - name=name, - ) - return [node] + return create_basic_op_node('Identity', node, kwargs) @mx_op.register("LeakyReLU") @@ -856,13 +629,7 @@ def convert_leakyrelu(node, **kwargs): """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) act_type = attrs.get("act_type", "leaky") alpha = float(attrs.get("slope", 0.25)) @@ -870,25 +637,16 @@ def convert_leakyrelu(node, **kwargs): act_name = {"elu": "Elu", "leaky": "LeakyRelu", "prelu": "PRelu", "selu": "Selu"} - if act_type == "prelu": - alpha_node_index = kwargs["index_lookup"][inputs[1][0]] - alpha_node_name = proc_nodes[alpha_node_index].name - - node = onnx.helper.make_node( - act_name[act_type], - inputs=[input_node, alpha_node_name], - outputs=[name], - name=name) - elif act_type == "selu": + if act_type == "prelu" or act_type == "selu": node = onnx.helper.make_node( act_name[act_type], - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name) else: node = onnx.helper.make_node( act_name[act_type], - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name, alpha=alpha) @@ -901,18 +659,13 @@ def convert_softmax(node, **kwargs): """Map MXNet's softmax operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] - input_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx] + name, input_nodes, attrs = get_inputs(node, kwargs) - name = node["name"] - axis = int(node.get("attrs", {}).get("axis", -1)) + axis = int(attrs.get("axis", -1)) softmax_node = onnx.helper.make_node( "Softmax", - [input_node.name], + input_nodes, [name], axis=axis, name=name @@ -928,12 +681,10 @@ def convert_softmax_output(node, **kwargs): """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] - input1_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input1 = proc_nodes[input1_idx] - name = node["name"] + name, _, _ = get_inputs(node, kwargs) + + input1_idx = kwargs["index_lookup"][node["inputs"][0][0]] + input1 = kwargs["proc_nodes"][input1_idx] softmax_node = onnx.helper.make_node( "Softmax", @@ -951,15 +702,12 @@ def convert_concat(node, **kwargs): """Map MXNet's Concat operator attributes to onnx's Concat operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - proc_nodes = kwargs["proc_nodes"] - input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in inputs] - axis = int(node.get("attrs", {}).get("dim", 1)) + name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = int(attrs.get("dim", 1)) concat_node = onnx.helper.make_node( "Concat", - input_names, + input_nodes, [name], axis=axis, name=name @@ -972,18 +720,15 @@ def convert_transpose(node, **kwargs): """Map MXNet's transpose operator attributes to onnx's Transpose operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name - axes = node.get("attrs", {}).get("axes", ()) + name, input_nodes, attrs = get_inputs(node, kwargs) + + axes = attrs.get("axes", ()) if axes: axes = tuple(map(int, re.findall(r'\d+', axes))) transpose_node = onnx.helper.make_node( "Transpose", - [input_node], + input_nodes, [name], perm=axes, name=name @@ -991,7 +736,7 @@ def convert_transpose(node, **kwargs): else: transpose_node = onnx.helper.make_node( "Transpose", - [input_node], + input_nodes, [name], name=name ) @@ -1004,21 +749,16 @@ def convert_lrn(node, **kwargs): """Map MXNet's LRN operator attributes to onnx's LRN operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name + name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node["attrs"] - alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001 - beta = float(attrs["beta"]) if "beta" in attrs else 0.75 - bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0 - size = int(attrs["nsize"]) + alpha = float(attrs.get("alpha", 0.0001)) + beta = float(attrs.get("beta", 0.75)) + bias = float(attrs.get("knorm", 1.0)) + size = int(attrs.get("nsize")) lrn_node = onnx.helper.make_node( "LRN", - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name, alpha=alpha, @@ -1035,11 +775,8 @@ def convert_l2normalization(node, **kwargs): """Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) + mode = attrs.get("mode", "instance") if mode != "channel": @@ -1047,7 +784,7 @@ def convert_l2normalization(node, **kwargs): l2norm_node = onnx.helper.make_node( "LpNormalization", - [input_name], + input_nodes, [name], axis=1, # channel only name=name @@ -1060,16 +797,13 @@ def convert_dropout(node, **kwargs): """Map MXNet's Dropout operator attributes to onnx's Dropout operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] - probability = float(attrs["p"]) + name, input_nodes, attrs = get_inputs(node, kwargs) + + probability = float(attrs.get("p", 0.5)) dropout_node = onnx.helper.make_node( "Dropout", - [input_name], + input_nodes, [name], ratio=probability, name=name @@ -1082,37 +816,21 @@ def convert_flatten(node, **kwargs): """Map MXNet's Flatten operator attributes to onnx's Flatten operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name # .output[0] - - flatten_node = onnx.helper.make_node( - "Flatten", - [input_node], - [name], - name=name - ) - return [flatten_node] + return create_basic_op_node('Flatten', node, kwargs) @mx_op.register("clip") def convert_clip(node, **kwargs): """Map MXNet's Clip operator attributes to onnx's Clip operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) + a_min = np.float(attrs.get('a_min', -np.inf)) a_max = np.float(attrs.get('a_max', np.inf)) clip_node = onnx.helper.make_node( "Clip", - [input_node], + input_nodes, [name], name=name, min=a_min, @@ -1123,21 +841,16 @@ def convert_clip(node, **kwargs): def scalar_op_helper(node, op_name, **kwargs): """Helper function for scalar arithmetic operations""" - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - scalar_value = [float(node.get("attrs", {}).get("scalar", 1))] + name, input_nodes, attrs = get_inputs(node, kwargs) - input_name_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_name_id].name + scalar_value = [float(attrs.get("scalar", 1))] initializer = kwargs["initializer"] flag = True # If the input value is in initializer, just multiply with scalar input # and create a new initializer for i in initializer: - if i.name == input_node: + if i.name == input_nodes[0]: if op_name == 'Mul': new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0] elif op_name == 'Sub': @@ -1170,7 +883,7 @@ def scalar_op_helper(node, op_name, **kwargs): mul_node = onnx.helper.make_node( op_name, - [input_node, scalar_op_name], + [input_nodes[0], scalar_op_name], [name], name=name ) @@ -1180,7 +893,7 @@ def scalar_op_helper(node, op_name, **kwargs): data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype] dims = np.shape(new_initializer) - new_a_node = input_node + str(kwargs["idx"]) + new_a_node = input_nodes[0] + str(kwargs["idx"]) tensor_node = onnx.helper.make_tensor_value_info(new_a_node, data_type, dims) initializer.append( @@ -1239,21 +952,14 @@ def convert_argmax(node, **kwargs): """Map MXNet's argmax operator attributes to onnx's ArgMax operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 node = onnx.helper.make_node( 'ArgMax', - inputs=[input_node], + inputs=input_nodes, axis=axis, keepdims=keepdims, outputs=[name], @@ -1266,21 +972,14 @@ def convert_argmin(node, **kwargs): """Map MXNet's argmin operator attributes to onnx's ArgMin operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 node = onnx.helper.make_node( 'ArgMin', - inputs=[input_node], + inputs=input_nodes, axis=axis, keepdims=keepdims, outputs=[name], @@ -1293,25 +992,7 @@ def convert_maximum(node, **kwargs): """Map MXNet's _maximum operator attributes to onnx's Max operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_list = [] - for node_input in node_inputs: - node_id = kwargs["index_lookup"][node_input[0]] - input_node_list.append(proc_nodes[node_id].name) - - name = node["name"] - - node = onnx.helper.make_node( - 'Max', - inputs=input_node_list, - outputs=[name], - name=name, - ) - - return [node] + return create_basic_op_node('Max', node, kwargs) @mx_op.register("_minimum") @@ -1319,49 +1000,24 @@ def convert_minimum(node, **kwargs): """Map MXNet's _minimum operator attributes to onnx's Min operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_list = [] - for node_input in node_inputs: - node_id = kwargs["index_lookup"][node_input[0]] - input_node_list.append(proc_nodes[node_id].name) - - name = node["name"] - - node = onnx.helper.make_node( - 'Min', - inputs=input_node_list, - outputs=[name], - name=name, - ) - - return [node] - + return create_basic_op_node('Min', node, kwargs) @mx_op.register("min") def convert_min(node, **kwargs): """Map MXNet's min operator attributes to onnx's ReduceMin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMin', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1372,7 +1028,7 @@ def convert_min(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMin', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1386,23 +1042,17 @@ def convert_max(node, **kwargs): """Map MXNet's max operator attributes to onnx's ReduceMax operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMax', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1413,7 +1063,7 @@ def convert_max(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMax', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1427,23 +1077,17 @@ def convert_mean(node, **kwargs): """Map MXNet's mean operator attributes to onnx's ReduceMean operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMean', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1454,7 +1098,7 @@ def convert_mean(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMean', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1468,23 +1112,17 @@ def convert_prod(node, **kwargs): """Map MXNet's prod operator attributes to onnx's ReduceProd operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceProd', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1495,7 +1133,7 @@ def convert_prod(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceProd', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1510,25 +1148,7 @@ def convert_elementwise_add(node, **kwargs): """Map MXNet's elemwise_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - add_node = onnx.helper.make_node( - "Add", - [input_node_a, input_node_b], - [name], - name=name, - ) - - return [add_node] + return create_basic_op_node('Add', node, kwargs) @mx_op.register("broadcast_add") @@ -1536,25 +1156,7 @@ def covert_broadcast_add(node, **kwargs): """Map MXNet's broadcast_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - add_node = onnx.helper.make_node( - "Add", - [input_node_a, input_node_b], - [name], - name=name, - ) - - return [add_node] + return create_basic_op_node('Add', node, kwargs) @mx_op.register("elemwise_sub") @@ -1562,224 +1164,63 @@ def convert_elementwise_sub(node, **kwargs): """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - sub_node = onnx.helper.make_node( - "Sub", - [input_node_a, input_node_b], - [name], - name=name, - ) - - return [sub_node] + return create_basic_op_node('Sub', node, kwargs) @mx_op.register("broadcast_sub") def covert_broadcast_sub(node, **kwargs): """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - sub_node = onnx.helper.make_node( - "Sub", - [input_node_a, input_node_b], - [name], - name=name, - ) - - return [sub_node] - + return create_basic_op_node('Sub', node, kwargs) @mx_op.register("elemwise_mul") def convert_elemwise_mul(node, **kwargs): """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - mul_node = onnx.helper.make_node( - "Mul", - [input_node_a, input_node_b], - [name], - name=name, - ) - - return [mul_node] + return create_basic_op_node('Mul', node, kwargs) @mx_op.register("broadcast_mul") def convert_broadcast_mul(node, **kwargs): """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - mul_node = onnx.helper.make_node( - "Mul", - [input_node_a, input_node_b], - [name], - name=name - ) - - return [mul_node] - + return create_basic_op_node('Mul', node, kwargs) @mx_op.register("elemwise_div") def convert_elemwise_div(node, **kwargs): """Map MXNet's elemwise_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - div_node = onnx.helper.make_node( - "Div", - [input_node_a, input_node_b], - [name], - name=name - ) - - return [div_node] - + return create_basic_op_node('Div', node, kwargs) @mx_op.register("broadcast_div") def convert_broadcast_div(node, **kwargs): """Map MXNet's broadcast_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - div_node = onnx.helper.make_node( - "Div", - [input_node_a, input_node_b], - [name], - name=name - ) - - return [div_node] - + return create_basic_op_node('Div', node, kwargs) @mx_op.register("negative") def convert_negative(node, **kwargs): """Map MXNet's negative operator attributes to onnx's Neg operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - - input_node = proc_nodes[input_node_id].name - - neg_node = onnx.helper.make_node( - "Neg", - [input_node], - [name], - name=name, - ) - - return [neg_node] - + return create_basic_op_node('Neg', node, kwargs) @mx_op.register("abs") def convert_abs(node, **kwargs): """Map MXNet's abs operator attributes to onnx's Abs operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - - input_node = proc_nodes[input_node_id].name - - abs_node = onnx.helper.make_node( - "Abs", - [input_node], - [name], - name=name - ) - - return [abs_node] - + return create_basic_op_node('Abs', node, kwargs) @mx_op.register("add_n") def convert_addn(node, **kwargs): """Map MXNet's add_n operator attributes to onnx's Sum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_list = [] - for input_val in inputs: - input_list.append(proc_nodes[kwargs["index_lookup"][input_val[0]]].name) - - sum_node = onnx.helper.make_node( - "Sum", - input_list, - [name], - name=name - ) - return [sum_node] + return create_basic_op_node('Sum', node, kwargs) # Rounding @mx_op.register("ceil") @@ -1787,42 +1228,14 @@ def convert_ceil(node, **kwargs): """Map MXNet's ceil operator attributes to onnx's Ceil operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Ceil", - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Ceil', node, kwargs) @mx_op.register("floor") def convert_floor(node, **kwargs): """Map MXNet's floor operator attributes to onnx's Floor operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Floor", - [input_node], - [name], - name=name - ) - return [node] + return create_basic_op_node('Floor', node, kwargs) # Changing shape and type. @mx_op.register("Reshape") @@ -1831,11 +1244,7 @@ def convert_reshape(node, **kwargs): Converts output shape attribute to output shape tensor and return multiple created nodes. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) output_shape_list = convert_string_to_list(attrs["shape"]) @@ -1857,8 +1266,7 @@ def convert_reshape(node, **kwargs): ) ) - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node_name = proc_nodes[input_node_idx].name + input_nodes.append(output_shape_name) not_supported_shape = [-2, -3, -4] @@ -1868,7 +1276,7 @@ def convert_reshape(node, **kwargs): reshape_node = onnx.helper.make_node( "Reshape", - [input_node_name, output_shape_name], + input_nodes, [name], name=name ) @@ -1880,11 +1288,9 @@ def convert_cast(node, **kwargs): """Map MXNet's Cast operator attributes to onnx's Cast operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - dtype = node["attrs"]["dtype"] + name, input_nodes, attrs = get_inputs(node, kwargs) + + dtype = attrs["dtype"] # dtype can be mapped only with types from TensorProto # float32 is mapped to float and float64 to double in onnx @@ -1894,12 +1300,9 @@ def convert_cast(node, **kwargs): elif dtype == 'float64': dtype = 'double' - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - node = onnx.helper.make_node( "Cast", - [input_node], + input_nodes, [name], to=getattr(onnx.TensorProto, dtype.upper()), name=name, @@ -1912,23 +1315,17 @@ def convert_slice_axis(node, **kwargs): """Map MXNet's slice_axis operator attributes to onnx's Slice operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axes = int(node["attrs"]["axis"]) - starts = int(node["attrs"]["begin"]) - if node["attrs"]["end"] == 'None': - raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") - else: - ends = int(node["attrs"]["end"]) + name, input_nodes, attrs = get_inputs(node, kwargs) - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + axes = int(attrs.get("axis")) + starts = int(attrs.get("begin")) + ends = int(attrs.get("end", None)) + if not ends: + raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") node = onnx.helper.make_node( "Slice", - [input_node], + input_nodes, [name], axes=[axes], starts=[starts], @@ -1944,21 +1341,16 @@ def convert_slice_channel(node, **kwargs): operator based on squeeze_axis attribute and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - num_outputs = int(node.get("attrs", {})["num_outputs"]) - axis = int(node.get("attrs", {}).get("axis", 1)) - squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0)) + name, input_nodes, attrs = get_inputs(node, kwargs) - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + num_outputs = int(attrs.get("num_outputs")) + axis = int(attrs.get("axis", 1)) + squeeze_axis = int(attrs.get("squeeze_axis", 0)) if squeeze_axis == 1 and num_outputs == 1: node = onnx.helper.make_node( "Squeeze", - [input_node], + input_nodes, [name], axes=[axis], name=name, @@ -1967,7 +1359,7 @@ def convert_slice_channel(node, **kwargs): elif squeeze_axis == 0 and num_outputs > 1: node = onnx.helper.make_node( "Split", - [input_node], + input_nodes, [name], axis=axis, split=[num_outputs], @@ -1984,18 +1376,13 @@ def convert_expand_dims(node, **kwargs): """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axis = int(node["attrs"]["axis"]) + name, input_nodes, attrs = get_inputs(node, kwargs) - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + axis = int(attrs.get("axis")) node = onnx.helper.make_node( "Unsqueeze", - [input_node], + input_nodes, [name], axes=[axis], name=name, @@ -2007,22 +1394,17 @@ def convert_squeeze(node, **kwargs): """Map MXNet's squeeze operator attributes to onnx's squeeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - if "axis" in node["attrs"]: - axis = convert_string_to_list(node["attrs"]["axis"]) - else: + name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = attrs.get("axis", None) + if not axis: raise AttributeError("Missing axis attribute: ONNX currently requires axis to " "be specified for squeeze operator") - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + axis = convert_string_to_list(axis) node = onnx.helper.make_node( "Squeeze", - [input_node], + input_nodes, [name], axes=axis, name=name, @@ -2035,132 +1417,48 @@ def convert_log(node, **kwargs): """Map MXNet's log operator attributes to onnx's Log operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Log", - [input_node], - [name], - name=name, - ) - return [node] - + return create_basic_op_node('Log', node, kwargs) @mx_op.register("reciprocal") def convert_reciprocal(node, **kwargs): """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Reciprocal", - [input_node], - [name], - name=name, - ) - return [node] + return create_basic_op_node('Reciprocal', node, kwargs) @mx_op.register("_power") def convert_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - node = onnx.helper.make_node( - "Pow", - [input_node_a, input_node_b], - [name], - name=name - ) - return [node] + return create_basic_op_node('Pow', node, kwargs) @mx_op.register("broadcast_power") def convert_broadcast_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name - - node = onnx.helper.make_node( - "Pow", - [input_node_a, input_node_b], - [name], - name=name - ) - return [node] + return create_basic_op_node('Pow', node, kwargs) @mx_op.register("sqrt") def convert_sqrt(node, **kwargs): """Map MXNet's sqrt operator attributes to onnx's Sqrt operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - - node = onnx.helper.make_node( - "Sqrt", - [input_node], - [name], - name=name, - ) - return [node] + return create_basic_op_node('Sqrt', node, kwargs) @mx_op.register("depth_to_space") def convert_depthtospace(node, **kwargs): """Map MXNet's depth_to_space operator attributes to onnx's DepthToSpace operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + name, input_nodes, attrs = get_inputs(node, kwargs) blksize = int(attrs.get("block_size", 0)) node = onnx.helper.make_node( "DepthToSpace", - [input_node], + input_nodes, [name], blocksize=blksize, name=name, @@ -2172,20 +1470,13 @@ def convert_spacetodepth(node, **kwargs): """Map MXNet's space_to_depth operator attributes to onnx's SpaceToDepth operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + name, input_nodes, attrs = get_inputs(node, kwargs) blksize = int(attrs.get("block_size", 0)) node = onnx.helper.make_node( "SpaceToDepth", - [input_node], + input_nodes, [name], blocksize=blksize, name=name, @@ -2197,13 +1488,7 @@ def convert_square(node, **kwargs): """Map MXNet's square operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_a = proc_nodes[input_node_a_id].name + name, input_nodes, _ = get_inputs(node, kwargs) initializer = kwargs["initializer"] data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] @@ -2220,9 +1505,11 @@ def convert_square(node, **kwargs): ) ) + input_nodes.append(power2_name) + node = onnx.helper.make_node( "Pow", - [input_node_a, power2_name], + input_nodes, [name], name=name ) @@ -2233,24 +1520,17 @@ def convert_sum(node, **kwargs): """Map MXNet's sum operator attributes to onnx's ReduceSum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + name, input_nodes, attrs = get_inputs(node, kwargs) mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None keepdims = get_boolean_attribute_value(attrs, "keepdims") - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - if axes: node = onnx.helper.make_node( 'ReduceSum', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -2259,7 +1539,7 @@ def convert_sum(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceSum', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 11847381ab24..b02d970f9c2d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -77,7 +77,11 @@ def register(op_name): """Register operators""" def wrapper(func): """Helper function to map functions""" - MXNetGraph.registry_[op_name] = func + try: + import onnx as _ + MXNetGraph.registry_[op_name] = func + except ImportError: + pass return func return wrapper