@@ -154,6 +154,29 @@ def create_basic_op_node(op_name, node, kwargs):
154154 )
155155 return [node ]
156156
def create_const_scalar_node(input_name, value, kwargs):
    """Register a rank-0 constant and return its value-info node.

    A scalar tensor named ``input_name`` that holds ``value`` is appended to
    ``kwargs["initializer"]``. The value-info entry with the same name,
    element type and an empty shape is returned to the caller.

    ``value`` must expose a ``dtype`` attribute that is a key of
    ``onnx.mapping.NP_TYPE_TO_TENSOR_TYPE``.
    """
    from onnx import helper as onnx_helper

    initializers = kwargs["initializer"]
    elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
    scalar_dims = ()  # empty dims == rank-0 tensor

    value_info = onnx_helper.make_tensor_value_info(input_name, elem_type, scalar_dims)
    initializers.append(
        onnx_helper.make_tensor(input_name, elem_type, scalar_dims, (value,))
    )
    return value_info
167+
def create_const_node(input_name, value, kwargs):
    """Register a constant tensor and return its value-info node.

    A tensor named ``input_name`` with the shape, element type and contents
    of ``value`` is appended to ``kwargs["initializer"]``. The matching
    value-info entry (same name, type and shape) is returned.

    Parameters
    ----------
    input_name : str
        Name shared by the initializer tensor and the value-info node.
    value : numpy.ndarray
        Constant data; its ``dtype`` must be a key of
        ``onnx.mapping.NP_TYPE_TO_TENSOR_TYPE``.
    kwargs : dict
        Converter keyword arguments; only ``kwargs["initializer"]`` is used.
    """
    from onnx.helper import make_tensor, make_tensor_value_info
    initializer = kwargs["initializer"]
    input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
    input_shape = value.shape
    value_node = make_tensor_value_info(input_name, input_type, input_shape)
    # make_tensor expects a flat sequence with prod(dims) elements; handing it
    # an N-d array directly only works for 1-D data.
    flat_values = value.flatten().tolist()
    tensor_node = make_tensor(input_name, input_type, input_shape, flat_values)
    initializer.append(tensor_node)
    return value_node
179+
157180@mx_op .register ("null" )
158181def convert_weights_and_inputs (node , ** kwargs ):
159182 """Helper function to convert weights and inputs.
@@ -802,6 +825,7 @@ def convert_leakyrelu(node, **kwargs):
802825 """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators
803826 based on the input node's attributes and return the created node.
804827 """
828+ from onnx .helper import make_node
805829 name , input_nodes , attrs = get_inputs (node , kwargs )
806830
807831 act_type = attrs .get ("act_type" , "leaky" )
@@ -816,6 +840,19 @@ def convert_leakyrelu(node, **kwargs):
816840 inputs = input_nodes ,
817841 outputs = [name ],
818842 name = name )
843+ elif act_type in ('gelu' ):
844+ sqrt2 = np .float32 (1.4142135623730951 )
845+ nodes = [
846+ create_const_scalar_node (name + "_sqrt2" , sqrt2 , kwargs ),
847+ make_node ("Div" , [input_nodes [0 ], name + "_sqrt2" ], [name + "_div0_out" ]),
848+ make_node ("Erf" , [name + "_div0_out" ], [name + "_erf0_out" ]),
849+ create_const_scalar_node (name + "_one" , np .float32 (1.0 ), kwargs ),
850+ create_const_scalar_node (name + "_half" , np .float32 (0.5 ), kwargs ),
851+ make_node ("Add" , [name + "_erf0_out" , name + "_one" ], [name + "_add0_out" ]),
852+ make_node ("Mul" , [input_nodes [0 ], name + "_add0_out" ], [name + "_mul0_out" ]),
853+ make_node ("Mul" , [name + "_mul0_out" , name + "_half" ], [name ])
854+ ]
855+ return nodes
819856 else :
820857 node = onnx .helper .make_node (
821858 act_name [act_type ],
@@ -2214,3 +2251,152 @@ def convert_take(node, **kwargs):
22142251 name = name ,
22152252 )
22162253 return [node ]
2254+
2255+
@mx_op.register("LayerNorm")
def convert_layer_norm(node, **kwargs):
    """Map MXNet's LayerNorm operator to a chain of primitive onnx operators.

    The emitted graph computes
    ``(x - mean(x)) / sqrt(mean((x - mean(x))**2) + eps) * gamma + beta``
    where both means are taken over the operator's ``axis`` attribute only
    (default ``-1``), keeping the reduced dimension so the result broadcasts
    against the input.

    Returns the list of value-info nodes (for the scalar constants) and
    operator nodes; the last node produces the output named after the
    MXNet node.
    """
    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)

    # in_shape[0] is taken to be the shape of the data input, as the
    # original implementation did when deriving the rank from it.
    rank = len(kwargs['in_shape'][0])
    axis = int(attrs.get('axis', -1))
    if axis < 0:
        axis += rank
    # Fall back to MXNet's documented default instead of turning a missing
    # attribute into NaN via np.float32(None).
    eps = float(attrs.get('eps', 1e-5))

    gamma, beta = input_nodes[1], input_nodes[2]
    nodes = []

    # gamma/beta are 1-D along `axis`. When that axis is not the last one,
    # append singleton dimensions so numpy-style broadcasting lines them up
    # with the normalized axis rather than with the last axis.
    trailing_dims = rank - 1 - axis
    if trailing_dims > 0:
        pad_axes = list(range(1, trailing_dims + 1))
        nodes.extend([
            make_node("Unsqueeze", [gamma], [name + "_gamma_unsqueezed"], axes=pad_axes),
            make_node("Unsqueeze", [beta], [name + "_beta_unsqueezed"], axes=pad_axes),
        ])
        gamma, beta = name + "_gamma_unsqueezed", name + "_beta_unsqueezed"

    nodes.extend([
        make_node("ReduceMean", [input_nodes[0]], [name + "_rm0_out"], axes=[axis]),
        make_node("Sub", [input_nodes[0], name + "_rm0_out"], [name + "_sub0_out"]),
        create_const_scalar_node(name + "_two", np.float32(2.), kwargs),
        make_node("Pow", [name + "_sub0_out", name + "_two"], [name + "_pow0_out"]),
        make_node("ReduceMean", [name + "_pow0_out"], [name + "_rm1_out"], axes=[axis]),
        create_const_scalar_node(name + "_eps", np.float32(eps), kwargs),
        make_node("Add", [name + "_rm1_out", name + "_eps"], [name + "_add0_out"]),
        make_node("Sqrt", [name + "_add0_out"], [name + "_sqrt0_out"]),
        make_node("Div", [name + "_sub0_out", name + "_sqrt0_out"], [name + "_div0_out"]),
        make_node("Mul", [name + "_div0_out", gamma], [name + "_mul0_out"]),
        make_node("Add", [name + "_mul0_out", beta], [name], name)
    ])

    return nodes
2281+
2282+
@mx_op.register("Embedding")
def convert_embedding(node, **kwargs):
    """Map MXNet's Embedding operator to onnx's Cast + Gather operators.

    MXNet's Embedding takes ``(indices, weight)`` whereas onnx's Gather takes
    ``(data, indices)``, so the two inputs are swapped. Gather only accepts
    integer indices, so the index input is cast to int64 first.

    Returns the Cast node followed by the Gather node; the Gather node
    produces the output named after the MXNet node.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)
    axis = int(attrs.get('axis', 0))

    indices_name = name + "_indices_int64"
    cast_node = onnx.helper.make_node(
        "Cast",
        [input_nodes[0]],
        [indices_name],
        to=int(onnx.TensorProto.INT64)
    )
    gather_node = onnx.helper.make_node(
        "Gather",
        [input_nodes[1], indices_name],  # (embedding table, indices)
        [name],
        axis=axis,
        name=name
    )
    return [cast_node, gather_node]
2297+
2298+
@mx_op.register("stack")
def convert_stack(node, **kwargs):
    """Map MXNet's stack operator to onnx Unsqueeze + Concat operators.

    Each input gets a new dimension inserted at ``axis`` (default 0) by its
    own Unsqueeze node; a final Concat along the same axis joins the
    unsqueezed tensors into the output named after the MXNet node.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)
    axis = int(attrs.get('axis', 0))

    # One intermediate tensor name per input, numbered in input order.
    unsqueezed_names = [
        name + "_unsqueeze" + str(position)
        for position in range(len(input_nodes))
    ]

    nodes = [
        onnx.helper.make_node(
            "Unsqueeze",
            inputs=[source],
            outputs=[target],
            axes=[axis]
        )
        for source, target in zip(input_nodes, unsqueezed_names)
    ]

    concat_node = onnx.helper.make_node(
        "Concat",
        inputs=unsqueezed_names,
        outputs=[name],
        name=name,
        axis=axis
    )
    nodes.append(concat_node)
    return nodes
2324+
2325+
@mx_op.register("slice")
def convert_slice(node, **kwargs):
    """Map MXNet's slice operator to onnx Slice operator.

    ``begin``, ``end`` and the optional ``step`` attribute are parsed with
    ``convert_string_to_list`` and emitted as int64 constant inputs of a
    Slice node. Entries parsed as ``None`` are replaced by open-ended bounds
    that onnx clamps to the actual dimension size.

    The onnx Slice inputs are ordered ``data, starts, ends, axes, steps``, so
    whenever steps are supplied an explicit ``axes`` input is emitted as well.

    Raises
    ------
    ValueError
        If ``begin``/``end`` (or a non-empty ``step``) differ in length.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)
    starts = convert_string_to_list(attrs.get("begin"))
    ends = convert_string_to_list(attrs.get("end"))
    # Attribute values arrive as strings; parse `step` the same way as
    # begin/end instead of feeding the raw string to numpy.
    step_attr = attrs.get("step")
    steps = convert_string_to_list(str(step_attr)) if step_attr else []

    if len(starts) != len(ends):
        raise ValueError("slice: 'begin' and 'end' must have the same length")
    if steps and len(steps) != len(starts):
        raise ValueError("slice: 'step' must be empty or match the length of 'begin'")

    # A missing stride means "every element".
    steps = [1 if stride is None else int(stride) for stride in steps]

    # Substitute open-ended bounds: onnx clamps out-of-range starts/ends, so
    # the int64 extremes mean "from/to the edge of the axis".
    int64_info = np.iinfo(np.int64)
    for i, stride in enumerate(steps or [1] * len(starts)):
        forward = stride > 0
        if starts[i] is None:
            starts[i] = 0 if forward else int64_info.max
        if ends[i] is None:
            ends[i] = int64_info.max if forward else int64_info.min

    # Force int64 so starts/ends/axes/steps share one index type on every
    # platform (numpy's default integer is not int64 everywhere).
    nodes = [
        create_const_node(name + "_begin", np.array(starts, dtype='int64'), kwargs),
        create_const_node(name + "_end", np.array(ends, dtype='int64'), kwargs)
    ]
    inputs = [input_nodes[0], name + "_begin", name + "_end"]
    if len(steps) > 0:
        # `steps` is the fifth Slice input; `axes` must precede it.
        axes = list(range(len(starts)))
        nodes.append(create_const_node(name + "_axes", np.array(axes, dtype='int64'), kwargs))
        nodes.append(create_const_node(name + "_steps", np.array(steps, dtype='int64'), kwargs))
        inputs.extend([name + "_axes", name + "_steps"])
    nodes.append(onnx.helper.make_node("Slice", inputs, [name], name=name))
    return nodes
2343+
2344+
@mx_op.register("zeros_like")
def convert_zeros_like(node, **kwargs):
    """Map MXNet's zeros_like operator attributes to onnx's ConstantOfShape operator.

    The shape of the first input (``kwargs['in_shape'][0]``) is stored as an
    int64 constant and fed to ConstantOfShape, whose ``value`` attribute is a
    one-element zero tensor of type ``kwargs['in_type']``.

    Returns the value-info node of the shape constant followed by the
    ConstantOfShape node, which produces the output named after the MXNet
    node.
    """
    from onnx.helper import make_node, make_tensor
    name, _, _ = get_inputs(node, kwargs)

    shape_name = name + "_shape"
    shape_value = np.array(kwargs['in_shape'][0], dtype='int64')
    zero_tensor = make_tensor(name + "_zero", kwargs['in_type'], [1], [0])

    nodes = [
        # Keep the value-info node so the shape constant is declared like the
        # constants emitted by the other converters in this file.
        create_const_node(shape_name, shape_value, kwargs),
        make_node("ConstantOfShape", [shape_name], [name], value=zero_tensor, name=name)
    ]
    return nodes
2359+
2360+
@mx_op.register("_contrib_arange_like")
def convert_arange_like(node, **kwargs):
    """Map MXNet's arange_like operator attributes to onnx's Range and Reshape operators.

    The output shape is taken from ``kwargs['in_shape'][0]``: the whole shape
    when ``axis`` is unset, otherwise a 1-D shape holding the size of that
    axis. A Range node generates ``prod(output_shape)`` values beginning at
    ``start`` with stride ``step``, and a Reshape node gives them the output
    shape.

    Raises
    ------
    AttributeError
        If the target opset is older than 11.
    NotImplementedError
        If ``repeat`` is not 1.
    """
    from onnx.helper import make_node
    name, _, attrs = get_inputs(node, kwargs)

    opset_version = kwargs['opset_version']
    if opset_version < 11:
        raise AttributeError("ONNX opset 11 or greater is required to export this operator")

    input_type = kwargs['in_type']
    np_dtype = np.dtype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type])
    in_shape = kwargs['in_shape']
    axis = attrs.get('axis')

    if axis is None or str(axis) == 'None':
        # output will be same shape as input
        output_shape = in_shape[0]
    else:
        # determine shape of axis
        output_shape = [in_shape[0][int(axis)]]

    repeat = int(float(attrs.get('repeat', 1)))
    if repeat != 1:
        raise NotImplementedError("arange_like operator with repeat != 1 not yet implemented.")

    # Range takes scalar inputs and create_const_scalar_node writes rank-0
    # tensors, so build genuine numpy scalars rather than 1-element arrays.
    start = np_dtype.type(float(attrs.get('start', 0.)))
    step = np_dtype.type(float(attrs.get('step', 1.)))
    tot_elements = int(np.prod(output_shape))
    limit = np_dtype.type(start + (tot_elements * step))

    # create constant inputs
    nodes = [
        create_const_scalar_node(name + "_start", start, kwargs),
        create_const_scalar_node(name + "_limit", limit, kwargs),
        create_const_scalar_node(name + "_step", step, kwargs),
        create_const_node(name + "_shape", np.array(output_shape, dtype='int64'), kwargs),
        make_node("Range", [name + "_start", name + "_limit", name + "_step"], [name + "_range0_out"]),
        make_node("Reshape", [name + "_range0_out", name + "_shape"], [name])
    ]
    return nodes
0 commit comments