@@ -2262,37 +2262,32 @@ def convert_tile(node, **kwargs):
22622262 """Map MXNet's Tile operator attributes to onnx's Tile
22632263 operator and return the created node.
22642264 """
2265+ from onnx .helper import make_node
22652266 name , input_nodes , attrs = get_inputs (node , kwargs )
22662267
2267- reps_list = convert_string_to_list (attrs ["reps" ])
2268-
2269- initializer = kwargs ["initializer" ]
2270- reps_shape_np = np .array (reps_list , dtype = 'int64' )
2271- data_type = onnx .mapping .NP_TYPE_TO_TENSOR_TYPE [reps_shape_np .dtype ]
2272- dims = np .shape (reps_shape_np )
2273-
2274- output_shape_name = "reps_attr_tensor" + str (kwargs ["idx" ])
2275- tensor_node = onnx .helper .make_tensor_value_info (output_shape_name , data_type , dims )
2268+ data = input_nodes [0 ]
2269+ reps = convert_string_to_list (attrs ["reps" ])
22762270
2277- initializer .append (
2278- onnx .helper .make_tensor (
2279- name = output_shape_name ,
2280- data_type = data_type ,
2281- dims = dims ,
2282- vals = reps_list ,
2283- raw = False ,
2284- )
2285- )
2271+ create_tensor ([0 ], name + '_0' , kwargs ['initializer' ])
2272+ create_tensor ([1 ], name + '_1' , kwargs ['initializer' ])
2273+ create_tensor (reps , name + '_reps' , kwargs ['initializer' ], dtype = 'int64' )
2274+ create_tensor ([len (reps )], name + '_reps_len' , kwargs ['initializer' ])
22862275
2287- input_nodes .append (output_shape_name )
2288- tile_node = onnx .helper .make_node (
2289- "Tile" ,
2290- input_nodes ,
2291- [name ],
2292- name = name
2293- )
2276+ nodes = [
2277+ make_node ('Shape' , [data ], [name + '_data_shape' ]),
2278+ make_node ('Shape' , [name + '_data_shape' ], [name + '_data_dim' ]),
2279+ make_node ('Max' , [name + '_data_dim' , name + '_reps_len' ], [name + '_max' ]),
2280+ make_node ('Sub' , [name + '_max' , name + '_data_dim' ], [name + '_data_diff' ]),
2281+ make_node ('Concat' , [name + '_data_diff' , name + '_0' ], [name + '_concat0_out' ], axis = 0 ),
2282+ make_node ('Pad' , [name + '_data_shape' , name + '_concat0_out' , name + '_1' ], [name + '_data_shape_pad' ]),
2283+ make_node ('Reshape' , [data , name + '_data_shape_pad' ], [name + '_data' ]),
2284+ make_node ('Sub' , [name + '_max' , name + '_reps_len' ], [name + '_reps_diff' ]),
2285+ make_node ('Concat' , [name + '_reps_diff' , name + '_0' ], [name + '_concat1_out' ], axis = 0 ),
2286+ make_node ('Pad' , [name + '_reps' , name + '_concat1_out' , name + '_1' ], [name + '_reps_pad' ]),
2287+ make_node ('Tile' , [name + '_data' , name + '_reps_pad' ], [name ], name = name ),
2288+ ]
22942289
2295- return [ tensor_node , tile_node ]
2290+ return nodes
22962291
22972292
22982293@mx_op .register ("broadcast_to" )
0 commit comments