From fbb96ba12434ad25e4a7e1716630f35008bf574f Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 01:02:44 +0000 Subject: [PATCH 1/6] optimization for bertgit status --- python/mxnet/onnx/mx2onnx/__init__.py | 3 +- python/mxnet/onnx/mx2onnx/_export_model.py | 10 +- python/mxnet/onnx/mx2onnx/_export_onnx.py | 25 +- .../onnx/mx2onnx/_op_translations/__init__.py | 3 +- .../_op_translations/_gluonnlp_bert.py | 254 ++++++++++++++++++ 5 files changed, 284 insertions(+), 11 deletions(-) create mode 100644 python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py diff --git a/python/mxnet/onnx/mx2onnx/__init__.py b/python/mxnet/onnx/mx2onnx/__init__.py index 339d74dd7c2d..2573958a6403 100644 --- a/python/mxnet/onnx/mx2onnx/__init__.py +++ b/python/mxnet/onnx/mx2onnx/__init__.py @@ -19,5 +19,4 @@ """ONNX Export module""" from ._export_model import export_model, get_operator_support -from ._op_translations import _op_translations_opset12 -from ._op_translations import _op_translations_opset13 +from ._op_translations import * diff --git a/python/mxnet/onnx/mx2onnx/_export_model.py b/python/mxnet/onnx/mx2onnx/_export_model.py index ad33c2aec7c8..bfd109d8931c 100644 --- a/python/mxnet/onnx/mx2onnx/_export_model.py +++ b/python/mxnet/onnx/mx2onnx/_export_model.py @@ -51,7 +51,7 @@ def get_operator_support(opset_version=None): def export_model(sym, params, in_shapes=None, in_types=np.float32, onnx_file_path='model.onnx', verbose=False, dynamic=False, dynamic_input_shapes=None, run_shape_inference=False, input_type=None, - input_shape=None): + input_shape=None, model_specific_logics=None): """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol,parameter objects as well as json and params filepaths as input. Operator support and coverage - @@ -83,6 +83,8 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32, This is the old name of in_types. We keep this parameter name for backward compatibility in_shapes : List of tuple This is the old name of in_shapes. We keep this parameter name for backward compatibility + model_specific_logics: str + Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/ Returns ------- @@ -122,12 +124,14 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32, onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, in_shapes, in_types_t, verbose=verbose, opset_version=opset_version, - dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes) + dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes, + model_specific_logics=model_specific_logics) elif isinstance(sym, symbol.Symbol) and isinstance(params, dict): onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes, in_types_t, verbose=verbose, opset_version=opset_version, - dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes) + dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes, + model_specific_logics=model_specific_logics) elif isinstance(sym, symbol.Symbol) and isinstance(params, list) and len(params) == 2: # when params contains arg_params and aux_params p = {} diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index 307095beff09..454d6eed14cc 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -92,14 +92,25 @@ def convert_layer(node, **kwargs): op = str(node["op"]) opset_version = kwargs.get("opset_version", onnx_opset_version()) # fallback to older opset versions if op is not registered in current version + convert_func = None for op_version in range(opset_version, 11, -1): if op_version not in MXNetGraph.registry_ or op not in MXNetGraph.registry_[op_version]: - if opset_version == 12: - raise AttributeError("No conversion function registered for op type %s yet." % op) continue convert_func = MXNetGraph.registry_[op_version][op] break + model_specific_logics = kwargs.get("model_specific_logics", None) + if model_specific_logics: + assert model_specific_logics in MXNetGraph.registry_,\ + "Model specific converion logics for %s is not found" % model_specific_logics + if op in MXNetGraph.registry_[model_specific_logics]: + logging.info("Found model-specific %s conversion logic for model %s", + op, model_specific_logics) + convert_func = MXNetGraph.registry_[model_specific_logics][op] + + if convert_func is None: + raise AttributeError("No conversion function registered for op type %s yet." % op) + ret = convert_func(node, **kwargs) # in case the conversion function does not specify the returned dtype, we just return None # as the second value @@ -239,7 +250,7 @@ def convert_weights_to_numpy(weights_dict): for k, v in weights_dict.items()]) def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, opset_version=None, - dynamic=True, dynamic_input_shapes=None): + dynamic=True, dynamic_input_shapes=None, model_specific_logics=None): """Convert MXNet graph to ONNX graph Parameters @@ -260,6 +271,8 @@ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=Fals If True will allow for dynamic input shapes to the model dynamic_input_shapes: list of tuple Specifies the dynamic input_shapes. If None then all dimensions are set to None + model_specific_logics: str + Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/ Returns ------- @@ -335,7 +348,8 @@ def __init__(self, name, dtype): in_type=in_types[graph_input_idx], proc_nodes=all_processed_nodes, initializer=initializer, - outputs_lookup=outputs_lookup) + outputs_lookup=outputs_lookup, + ) graph_input_idx += 1 else: # Handle graph layers @@ -348,7 +362,8 @@ def __init__(self, name, dtype): initializer=initializer, outputs_lookup=outputs_lookup, idx=idx, - opset_version=opset_version + opset_version=opset_version, + model_specific_logics=model_specific_logics ) if isinstance(converted, list): # Collect all the node's output names diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py index ba26e207eb4c..5d1bc5019e7b 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py @@ -18,5 +18,6 @@ # coding: utf-8 """ONNX export op translation""" +from . import _gluonnlp_bert from . import _op_translations_opset12 -from . import _op_translations_opset13 +from . import _op_translations_opset13 diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py new file mode 100644 index 000000000000..9dfd40888a67 --- /dev/null +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py @@ -0,0 +1,254 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""GluonNLP BERT specific translation logics""" + + +import re +import logging +import numpy as np +from .._export_onnx import MXNetGraph as mx_op +try: + import onnx +except ImportError: + onnx = None + + +def get_cheat_sheet(): + cheat_sheet = { + 'qkv_hidden': 768, + 'num_heads': 12, + 'head_dim': 64 + } + + return cheat_sheet + + +def get_inputs(node, kwargs): + """Helper function to get inputs""" + name = node["name"] + outputs_lookup = kwargs["outputs_lookup"] + inputs = node["inputs"] + attrs = node.get("attrs", {}) + + input_nodes = [] + for ip in inputs: + input_node_name = outputs_lookup[ip[0]][ip[1]].name + input_nodes.append(input_node_name) + + return name, input_nodes, attrs + + +def get_input_dtypes(node, kwargs): + outputs_lookup = kwargs['outputs_lookup'] + inputs = node['inputs'] + input_dtypes = [] + for ip in inputs: + input_node_dtype = outputs_lookup[ip[0]][ip[1]].dtype + input_dtypes.append(input_node_dtype) + return input_dtypes + + +def get_boolean_attribute_value(attrs, attr_name): + """ Helper function to convert a string version + of Boolean attributes to integer for ONNX. + Takes attribute dictionary and attr_name as + parameters. + """ + return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 + + +def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): + """Helper function to create a tensor value node and a + initializer tensor node with constant value.""" + tensor_np = np.array(tensor_list, dtype=dtype) + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[tensor_np.dtype] + dims = np.shape(tensor_np) + if dtype == np.float16: + tensor_np = tensor_np.view(dtype=np.uint16) + initializer.append( + onnx.helper.make_tensor( + name=tensor_name, + data_type=data_type, + dims=dims, + vals=tensor_np.flatten().tolist(), + raw=False + ) + ) + + +@mx_op.register("FullyConnected", opset_version='gluonnlp_bert') +def convert_fully_connected(node, **kwargs): + """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator + and return the created node. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + dtype = input_dtypes[0] + flatten = get_boolean_attribute_value(attrs, 'flatten') + no_bias = get_boolean_attribute_value(attrs, 'no_bias') + num_hidden = int(attrs.get('num_hidden')) + + nodes = [] + + if 'dotproductselfattentioncell' in name: + cheat_sheet = get_cheat_sheet() + qkv_hidden = cheat_sheet['qkv_hidden'] + num_heads = cheat_sheet['num_heads'] + head_dim = cheat_sheet['head_dim'] + create_tensor([num_heads, 3 * head_dim, qkv_hidden], name+'_interleaved_w_shape', kwargs['initializer']) + create_tensor([num_heads, 3 * head_dim], name+'_interleaved_b_shape', kwargs['initializer']) + create_tensor([qkv_hidden, qkv_hidden], name+'_w_shape', kwargs['initializer']) + create_tensor([qkv_hidden], name+'_b_shape', kwargs['initializer']) + nodes += [ + make_node('Reshape', [input_nodes[1], name+'_interleaved_w_shape'], [name+'_interleaved_w']), + make_node('Split', [name+'_interleaved_w'], [name+'_q_w_', name+'_k_w_', name+'_v_w_'], axis=1), + make_node('Reshape', [name+'_q_w_', name+'_w_shape'], [name+'_q_w_reshaped']), + make_node('Reshape', [name+'_k_w_', name+'_w_shape'], [name+'_k_w_reshaped']), + make_node('Reshape', [name+'_v_w_', name+'_w_shape'], [name+'_v_w_reshaped']), + make_node('Transpose', [name+'_q_w_reshaped'], [name+'_q_w']), + make_node('Transpose', [name+'_k_w_reshaped'], [name+'_k_w']), + make_node('Transpose', [name+'_v_w_reshaped'], [name+'_v_w']), + make_node('Reshape', [input_nodes[2], name+'_interleaved_b_shape'], [name+'_interleaved_b']), + make_node('Split', [name+'_interleaved_b'], [name+'_q_b_', name+'_k_b_', name+'_v_b_'], axis=1), + make_node('Reshape', [name+'_q_b_', name+'_b_shape'], [name+'_q_b']), + make_node('Reshape', [name+'_k_b_', name+'_b_shape'], [name+'_k_b']), + make_node('Reshape', [name+'_v_b_', name+'_b_shape'], [name+'_v_b']), + make_node('MatMul', [input_nodes[0], name+'_q_w'], [name+'_q_']), + make_node('MatMul', [input_nodes[0], name+'_k_w'], [name+'_k_']), + make_node('MatMul', [input_nodes[0], name+'_v_w'], [name+'_v_']), + make_node('Add', [name+'_q_', name+'_q_b'], [name+'0']), + make_node('Add', [name+'_k_', name+'_k_b'], [name+'1']), + make_node('Add', [name+'_v_', name+'_v_b'], [name+'2']), + ] + return nodes + + if flatten: + nodes += [ + make_node('Flatten', [input_nodes[0]], [name+'_data_flattened']) + ] + else: + nodes += [ + make_node('Shape', [input_nodes[0]], [name+'_orig_shape']), + make_node('Shape', [name+'_orig_shape'], [name+'_dim']), + make_node('Flatten', [input_nodes[0]], [name+'_data_flattened'], axis=-1), + ] + in_nodes = [name+'_data_flattened', input_nodes[1]] + if no_bias: + create_const_scalar_node(name+'_bias', np.int32(0).astype(dtype), kwargs) + in_nodes.append(name+'_bias') + else: + in_nodes.append(input_nodes[2]) + if flatten: + nodes += [ + make_node('Gemm', in_nodes, [name], alpha=1.0, beta=1.0, transA=0, transB=1, name=name) + ] + else: + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([num_hidden], name+'_num_hidden', kwargs['initializer']) + nodes += [ + make_node('Gemm', in_nodes, [name+'_gemm'], alpha=1.0, beta=1.0, transA=0, transB=1), + make_node('Sub', [name+'_dim', name+'_1'], [name+'dim_minus_1']), + make_node('Slice', [name+'_orig_shape', name+'_0', name+'dim_minus_1'], + [name+'_shape_sliced']), + make_node('Concat', [name+'_shape_sliced', name+'_num_hidden'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [name+'_gemm', name+'_shape_new'], [name], name=name) + ] + return nodes + + +@mx_op.register("_contrib_interleaved_matmul_selfatt_qk", opset_version='gluonnlp_bert') +def convert_matmul_selfatt_qk(node, **kwargs): + """Map MXNet's _contrib_interleaved_matmul_selfatt_qk operator + """ + from onnx.helper import make_node + from onnx import TensorProto + import copy + + inp0 = node['inputs'][0] + inp1 = copy.deepcopy(inp0) + inp1[1] = 1 + node['inputs'] = [inp0, inp1] + name, input_nodes, attrs = get_inputs(node, kwargs) + + cheat_sheet = get_cheat_sheet() + qkv_hidden = cheat_sheet['qkv_hidden'] + num_heads = cheat_sheet['num_heads'] + head_dim = cheat_sheet['head_dim'] + + create_tensor([-1], name+'_m1', kwargs['initializer']) + create_tensor([int(head_dim ** 0.5)], name+'_sqrt_head_dim', kwargs['initializer'], dtype='float32') + create_tensor([0, 0, num_heads, head_dim], name+"_qkv_shape", kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_sq', name+'_bs', name+'___'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_qkv_shape'], [name+'_q_']), + make_node('Reshape', [input_nodes[1], name+'_qkv_shape'], [name+'_k_']), + make_node('Transpose', [name+'_q_'], [name+'_q'], perm=[1, 2, 0, 3]), + make_node('Transpose', [name+'_k_'], [name+'_k'], perm=[1, 2, 3, 0]), + make_node('MatMul', [name+'_q', name+'_k'], [name+'_qk']), + make_node('Concat', [name+'_m1', name+'_sq', name+'_sq'], [name+'_out_shape'], axis=0), + make_node('Reshape', [name+'_qk', name+'_out_shape'], [name+'_qk_reshaped']), + make_node('Div', [name+'_qk_reshaped', name+'_sqrt_head_dim'], [name]) + ] + + return nodes + + +@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt", opset_version='gluonnlp_bert') +def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): + """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator. + """ + from onnx.helper import make_node + inp0 = node['inputs'][0] + inp0[1] = 2 + inp1 = node['inputs'][1] + node['inputs'] = [inp0, inp1] + name, input_nodes, attrs = get_inputs(node, kwargs) + + cheat_sheet = get_cheat_sheet() + qkv_hidden = cheat_sheet['qkv_hidden'] + num_heads = cheat_sheet['num_heads'] + head_dim = cheat_sheet['head_dim'] + + create_tensor([head_dim], name+'_head_dim', kwargs["initializer"]) + create_tensor([0], name+'_0', kwargs["initializer"]) + create_tensor([-1], name+'_m1', kwargs["initializer"]) + create_tensor([0, 0, num_heads, head_dim], name+"_qkv_shape", kwargs["initializer"]) + create_tensor([0, 0, -1], name+"_out_shape", kwargs["initializer"]) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_sq', name+'_bs', name+'___'], axis=0), + make_node('Reshape', [input_nodes[0], name+"_qkv_shape"], [name+'_v__']), + make_node('Transpose', [name+'_v__'], [name+'_v_'], perm=[1, 2, 0, 3]), + make_node('Concat', [name+'_m1', name+'_sq', name+'_head_dim'], [name+'_v_shape'], axis=0), + make_node('Reshape', [name+'_v_', name+'_v_shape'], [name+'_v']), + make_node('MatMul', [input_nodes[1], name+'_v'], [name+'_matmul']), + make_node('Concat', [name+'_bs', name+'_m1', name+'_sq', name+'_head_dim'], + [name+'_before_transpose'], axis=0), + make_node('Reshape', [name+'_matmul', name+'_before_transpose'], [name+'_bt']), + make_node('Transpose', [name+'_bt'], [name+'_transpose'], perm=[2, 0, 1, 3]), + make_node('Reshape', [name+'_transpose', name+'_out_shape'], [name]) + ] + + return nodes + From 86af75dfa7e302d32dffdb59675c7d7abe64c934 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 04:54:25 +0000 Subject: [PATCH 2/6] fix sanity --- .../_op_translations/_gluonnlp_bert.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py index 9dfd40888a67..d10e7936a43b 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py @@ -18,10 +18,6 @@ # coding: utf-8 """GluonNLP BERT specific translation logics""" - - -import re -import logging import numpy as np from .._export_onnx import MXNetGraph as mx_op try: @@ -74,6 +70,34 @@ def get_boolean_attribute_value(attrs, attr_name): return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 +def create_const_scalar_node(input_name, value, kwargs): + """Helper function to create a tensor value node and a + initializer tensor node with constant value.""" + from onnx.helper import make_tensor + initializer = kwargs["initializer"] + dtype = value.dtype + if dtype == 'float16': + # when using float16, we must convert it to np.uint16 view first + value = np.float16(value).view(np.uint16) + input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + tensor_node = make_tensor(input_name, input_type, (), ([value])) + initializer.append(tensor_node) + +def create_const_node(input_name, value, kwargs): + """Helper function to create a tensor value node and a + initializer tensor node with constant value.""" + from onnx.helper import make_tensor + initializer = kwargs["initializer"] + dtype = value.dtype + if dtype == 'float16': + # when using float16, we must convert it to np.uint16 view first + value = np.float16(value).view(np.uint16) + input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + input_shape = value.shape + tensor_node = make_tensor(input_name, input_type, input_shape, value) + initializer.append(tensor_node) + + def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): """Helper function to create a tensor value node and a initializer tensor node with constant value.""" @@ -181,17 +205,15 @@ def convert_matmul_selfatt_qk(node, **kwargs): """Map MXNet's _contrib_interleaved_matmul_selfatt_qk operator """ from onnx.helper import make_node - from onnx import TensorProto import copy inp0 = node['inputs'][0] inp1 = copy.deepcopy(inp0) inp1[1] = 1 node['inputs'] = [inp0, inp1] - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, _ = get_inputs(node, kwargs) cheat_sheet = get_cheat_sheet() - qkv_hidden = cheat_sheet['qkv_hidden'] num_heads = cheat_sheet['num_heads'] head_dim = cheat_sheet['head_dim'] @@ -223,10 +245,9 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): inp0[1] = 2 inp1 = node['inputs'][1] node['inputs'] = [inp0, inp1] - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, _ = get_inputs(node, kwargs) cheat_sheet = get_cheat_sheet() - qkv_hidden = cheat_sheet['qkv_hidden'] num_heads = cheat_sheet['num_heads'] head_dim = cheat_sheet['head_dim'] @@ -251,4 +272,3 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): ] return nodes - From f741272b4093ae196b04a04cbb5616363ee2ebb8 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 05:09:45 +0000 Subject: [PATCH 3/6] fix sanity --- python/mxnet/onnx/mx2onnx/_op_translations/__init__.py | 2 +- python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py index 5d1bc5019e7b..f3c8cbd889b1 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py @@ -20,4 +20,4 @@ from . import _gluonnlp_bert from . import _op_translations_opset12 -from . import _op_translations_opset13 +from . import _op_translations_opset13 diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py index d10e7936a43b..51186f974363 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py @@ -78,7 +78,7 @@ def create_const_scalar_node(input_name, value, kwargs): dtype = value.dtype if dtype == 'float16': # when using float16, we must convert it to np.uint16 view first - value = np.float16(value).view(np.uint16) + value = np.float16(value).view(np.uint16) #pylint: disable=too-many-function-args input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] tensor_node = make_tensor(input_name, input_type, (), ([value])) initializer.append(tensor_node) @@ -91,7 +91,7 @@ def create_const_node(input_name, value, kwargs): dtype = value.dtype if dtype == 'float16': # when using float16, we must convert it to np.uint16 view first - value = np.float16(value).view(np.uint16) + value = np.float16(value).view(np.uint16) #pylint: disable=too-many-function-args input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] input_shape = value.shape tensor_node = make_tensor(input_name, input_type, input_shape, value) From 259743ce9e9f035fe232953ae73272ec4e8f5d07 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 21:14:19 +0000 Subject: [PATCH 4/6] add test case --- python/mxnet/onnx/mx2onnx/_export_model.py | 8 ++- python/mxnet/onnx/mx2onnx/_export_onnx.py | 8 ++- .../onnx/mx2onnx/_op_translations/__init__.py | 2 +- ...ert.py => _gluonnlp_bert_uninterleaved.py} | 30 ++++----- .../onnx/test_onnxruntime_nlp.py | 62 ++++++++++++++++--- 5 files changed, 79 insertions(+), 31 deletions(-) rename python/mxnet/onnx/mx2onnx/_op_translations/{_gluonnlp_bert.py => _gluonnlp_bert_uninterleaved.py} (95%) diff --git a/python/mxnet/onnx/mx2onnx/_export_model.py b/python/mxnet/onnx/mx2onnx/_export_model.py index bfd109d8931c..3d54c78e0958 100644 --- a/python/mxnet/onnx/mx2onnx/_export_model.py +++ b/python/mxnet/onnx/mx2onnx/_export_model.py @@ -51,7 +51,7 @@ def get_operator_support(opset_version=None): def export_model(sym, params, in_shapes=None, in_types=np.float32, onnx_file_path='model.onnx', verbose=False, dynamic=False, dynamic_input_shapes=None, run_shape_inference=False, input_type=None, - input_shape=None, model_specific_logics=None): + input_shape=None, model_specific_logics=None, cheat_sheet=None): """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol,parameter objects as well as json and params filepaths as input. Operator support and coverage - @@ -125,13 +125,15 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32, in_types_t, verbose=verbose, opset_version=opset_version, dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes, - model_specific_logics=model_specific_logics) + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) elif isinstance(sym, symbol.Symbol) and isinstance(params, dict): onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes, in_types_t, verbose=verbose, opset_version=opset_version, dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes, - model_specific_logics=model_specific_logics) + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) elif isinstance(sym, symbol.Symbol) and isinstance(params, list) and len(params) == 2: # when params contains arg_params and aux_params p = {} diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index 454d6eed14cc..c2046eab78ba 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -249,8 +249,9 @@ def convert_weights_to_numpy(weights_dict): return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy()) for k, v in weights_dict.items()]) - def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, opset_version=None, - dynamic=True, dynamic_input_shapes=None, model_specific_logics=None): + def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, + opset_version=None, dynamic=True, dynamic_input_shapes=None, + model_specific_logics=None, cheat_sheet=None): """Convert MXNet graph to ONNX graph Parameters @@ -363,7 +364,8 @@ def __init__(self, name, dtype): outputs_lookup=outputs_lookup, idx=idx, opset_version=opset_version, - model_specific_logics=model_specific_logics + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet ) if isinstance(converted, list): # Collect all the node's output names diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py index f3c8cbd889b1..663d36fd5e59 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/__init__.py @@ -18,6 +18,6 @@ # coding: utf-8 """ONNX export op translation""" -from . import _gluonnlp_bert +from . import _gluonnlp_bert_uninterleaved from . import _op_translations_opset12 from . import _op_translations_opset13 diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py similarity index 95% rename from python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py rename to python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py index 51186f974363..5b8535d48029 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py @@ -19,20 +19,22 @@ # coding: utf-8 """GluonNLP BERT specific translation logics""" import numpy as np +import logging from .._export_onnx import MXNetGraph as mx_op try: import onnx except ImportError: onnx = None - -def get_cheat_sheet(): - cheat_sheet = { - 'qkv_hidden': 768, - 'num_heads': 12, - 'head_dim': 64 - } - +def get_cheat_sheet(kwargs): + cheat_sheet = kwargs.get('cheat_sheet', None) + if cheat_sheet is None: + logging.warning('cheat_sheet not found, using default vallues') + cheat_sheet = { + 'qkv_hidden': 768, + 'num_heads': 12, + 'head_dim': 64 + } return cheat_sheet @@ -117,7 +119,7 @@ def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): ) -@mx_op.register("FullyConnected", opset_version='gluonnlp_bert') +@mx_op.register("FullyConnected", opset_version='gluonnlp_bert_uninterleaved') def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator and return the created node. @@ -133,7 +135,7 @@ def convert_fully_connected(node, **kwargs): nodes = [] if 'dotproductselfattentioncell' in name: - cheat_sheet = get_cheat_sheet() + cheat_sheet = get_cheat_sheet(kwargs) qkv_hidden = cheat_sheet['qkv_hidden'] num_heads = cheat_sheet['num_heads'] head_dim = cheat_sheet['head_dim'] @@ -200,7 +202,7 @@ def convert_fully_connected(node, **kwargs): return nodes -@mx_op.register("_contrib_interleaved_matmul_selfatt_qk", opset_version='gluonnlp_bert') +@mx_op.register("_contrib_interleaved_matmul_selfatt_qk", opset_version='gluonnlp_bert_uninterleaved') def convert_matmul_selfatt_qk(node, **kwargs): """Map MXNet's _contrib_interleaved_matmul_selfatt_qk operator """ @@ -213,7 +215,7 @@ def convert_matmul_selfatt_qk(node, **kwargs): node['inputs'] = [inp0, inp1] name, input_nodes, _ = get_inputs(node, kwargs) - cheat_sheet = get_cheat_sheet() + cheat_sheet = get_cheat_sheet(kwargs) num_heads = cheat_sheet['num_heads'] head_dim = cheat_sheet['head_dim'] @@ -236,7 +238,7 @@ def convert_matmul_selfatt_qk(node, **kwargs): return nodes -@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt", opset_version='gluonnlp_bert') +@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt", opset_version='gluonnlp_bert_uninterleaved') def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator. """ @@ -247,7 +249,7 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): node['inputs'] = [inp0, inp1] name, input_nodes, _ = get_inputs(node, kwargs) - cheat_sheet = get_cheat_sheet() + cheat_sheet = get_cheat_sheet(kwargs) num_heads = cheat_sheet['num_heads'] head_dim = cheat_sheet['head_dim'] diff --git a/tests/python-pytest/onnx/test_onnxruntime_nlp.py b/tests/python-pytest/onnx/test_onnxruntime_nlp.py index b5200757c75b..a917f249ea4e 100644 --- a/tests/python-pytest/onnx/test_onnxruntime_nlp.py +++ b/tests/python-pytest/onnx/test_onnxruntime_nlp.py @@ -30,7 +30,8 @@ @with_seed() @pytest.mark.parametrize('model_name', ['roberta_24_1024_16', 'roberta_12_768_12']) -def test_roberta_inference_onnxruntime(tmp_path, model_name): +@pytest.mark.parametrize('model_specific_logics', [None, 'gluonnlp_bert_uninterleaved']) +def test_roberta_inference_onnxruntime(tmp_path, model_name, model_specific_logics): tmp_path = str(tmp_path) try: import gluonnlp as nlp @@ -63,9 +64,19 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name): params_file = "%s-0000.params" % prefix onnx_file = "%s.onnx" % prefix input_shapes = [(batch, seq_length), (batch,), (batch, num_masked_positions)] + + # this is only for when model_specific_logics='gluonnlp_bert_uninterleaved' + cheat_sheet = { + 'qkv_hidden': 768 if model_name == 'roberta_12_768_12' else 1024, + 'num_heads': 12 if model_name == 'roberta_12_768_12' else 16, + 'head_dim': 64 + } + converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, [np.float32, np.float32, np.int32], - onnx_file, verbose=True) + onnx_file, verbose=True, + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL @@ -84,15 +95,16 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name): @with_seed() @pytest.mark.integration -@pytest.mark.parametrize('model', ['bert_12_768_12', 'bert_24_1024_16']) -def test_bert_inference_onnxruntime(tmp_path, model): +@pytest.mark.parametrize('model_name', ['bert_12_768_12', 'bert_24_1024_16']) +@pytest.mark.parametrize('model_specific_logics', [None, 'gluonnlp_bert_uninterleaved']) +def test_bert_inference_onnxruntime(tmp_path, model_name, model_specific_logics): tmp_path = str(tmp_path) try: import gluonnlp as nlp dataset = 'book_corpus_wiki_en_uncased' ctx = mx.cpu(0) model, vocab = nlp.model.get_model( - name=model, + name=model_name, ctx=ctx, dataset_name=dataset, pretrained=True, @@ -117,10 +129,19 @@ def test_bert_inference_onnxruntime(tmp_path, model): params_file = "%s-0000.params" % prefix onnx_file = "%s.onnx" % prefix + # this is only for when model_specific_logics='gluonnlp_bert_uninterleaved' + cheat_sheet = { + 'qkv_hidden': 768 if model_name == 'bert_12_768_12' else 1024, + 'num_heads': 12 if model_name == 'bert_12_768_12' else 16, + 'head_dim': 64 + } input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] input_types = [np.float32, np.float32, np.float32] - converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file) + converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, + input_types, onnx_file, + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) # create onnxruntime session using the generated onnx file @@ -140,7 +161,8 @@ def test_bert_inference_onnxruntime(tmp_path, model): @with_seed() @pytest.mark.parametrize('model_name', ['distilbert_6_768_12']) -def test_distilbert_inference_onnxruntime(tmp_path, model_name): +@pytest.mark.parametrize('model_specific_logics', [None, 'gluonnlp_bert_uninterleaved']) +def test_distilbert_inference_onnxruntime(tmp_path, model_name, model_specific_logics): tmp_path = str(tmp_path) try: import gluonnlp as nlp @@ -169,9 +191,18 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name): onnx_file = "%s.onnx" % prefix input_shapes = [(batch, seq_length), (batch,)] + # this is only for when model_specific_logics='gluonnlp_bert_uninterleaved' + cheat_sheet = { + 'qkv_hidden': 768, + 'num_heads': 12, + 'head_dim': 64 + } converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, [np.float32, np.float32], - onnx_file, verbose=True) + onnx_file, verbose=True, + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) + sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess = onnxruntime.InferenceSession(onnx_file, sess_options) @@ -372,7 +403,8 @@ def test_awd_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq @with_seed() @pytest.mark.parametrize('model_name', ['ernie_12_768_12']) -def test_ernie_inference_onnxruntime(tmp_path, model_name): +@pytest.mark.parametrize('model_specific_logics', [None, 'gluonnlp_bert_uninterleaved']) +def test_ernie_inference_onnxruntime(tmp_path, model_name, model_specific_logics): tmp_path = str(tmp_path) try: import gluonnlp as nlp @@ -408,8 +440,18 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name): input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] input_types = [np.float32, np.float32, np.float32] + + # this is only for when model_specific_logics='gluonnlp_bert_uninterleaved' + cheat_sheet = { + 'qkv_hidden': 768, + 'num_heads': 12, + 'head_dim': 64 + } + converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, - input_types, onnx_file) + input_types, onnx_file, + model_specific_logics=model_specific_logics, + cheat_sheet=cheat_sheet) # create onnxruntime session using the generated onnx file ses_opt = onnxruntime.SessionOptions() From 3ab73cb7a4a9a9cfd3391655861dd35d811dec55 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 21:18:47 +0000 Subject: [PATCH 5/6] add doc sring --- python/mxnet/onnx/mx2onnx/_export_model.py | 5 ++++- python/mxnet/onnx/mx2onnx/_export_onnx.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/mxnet/onnx/mx2onnx/_export_model.py b/python/mxnet/onnx/mx2onnx/_export_model.py index 3d54c78e0958..7287c7c980e2 100644 --- a/python/mxnet/onnx/mx2onnx/_export_model.py +++ b/python/mxnet/onnx/mx2onnx/_export_model.py @@ -83,8 +83,11 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32, This is the old name of in_types. We keep this parameter name for backward compatibility in_shapes : List of tuple This is the old name of in_shapes. We keep this parameter name for backward compatibility - model_specific_logics: str + model_specific_logics : str Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/ + cheat_sheet : dict of str to str + This is a dict that stors some hyperparameters values or additional info about the model that + would be used in model-specific conversion functions Returns ------- diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index c2046eab78ba..07b27908f853 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -274,6 +274,9 @@ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=Fals Specifies the dynamic input_shapes. If None then all dimensions are set to None model_specific_logics: str Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/ + cheat_sheet : dict of str to str + This is a dict that stors some hyperparameters values or additional info about the model that + would be used in model-specific conversion functions Returns ------- From cdf05967e6e83b3f3af6346609d2667f23f87f64 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 6 May 2021 21:59:03 +0000 Subject: [PATCH 6/6] fix sanity --- .../mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py index 5b8535d48029..8ddc9a9e6db5 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_gluonnlp_bert_uninterleaved.py @@ -18,8 +18,8 @@ # coding: utf-8 """GluonNLP BERT specific translation logics""" -import numpy as np import logging +import numpy as np from .._export_onnx import MXNetGraph as mx_op try: import onnx @@ -33,8 +33,7 @@ def get_cheat_sheet(kwargs): cheat_sheet = { 'qkv_hidden': 768, 'num_heads': 12, - 'head_dim': 64 - } + 'head_dim': 64} return cheat_sheet