Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d974cba

Browse files
waytrue17Wei Chu
andauthored
[v1.x] ONNX export rewrite tile (#19868)
* fix tile * fix sanity Co-authored-by: Wei Chu <weichu@amazon.com>
1 parent 26afc44 commit d974cba

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,37 +2262,32 @@ def convert_tile(node, **kwargs):
22622262
"""Map MXNet's Tile operator attributes to onnx's Tile
22632263
operator and return the created node.
22642264
"""
2265+
from onnx.helper import make_node
22652266
name, input_nodes, attrs = get_inputs(node, kwargs)
22662267

2267-
reps_list = convert_string_to_list(attrs["reps"])
2268-
2269-
initializer = kwargs["initializer"]
2270-
reps_shape_np = np.array(reps_list, dtype='int64')
2271-
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype]
2272-
dims = np.shape(reps_shape_np)
2273-
2274-
output_shape_name = "reps_attr_tensor" + str(kwargs["idx"])
2275-
tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims)
2268+
data = input_nodes[0]
2269+
reps = convert_string_to_list(attrs["reps"])
22762270

2277-
initializer.append(
2278-
onnx.helper.make_tensor(
2279-
name=output_shape_name,
2280-
data_type=data_type,
2281-
dims=dims,
2282-
vals=reps_list,
2283-
raw=False,
2284-
)
2285-
)
2271+
create_tensor([0], name+'_0', kwargs['initializer'])
2272+
create_tensor([1], name+'_1', kwargs['initializer'])
2273+
create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64')
2274+
create_tensor([len(reps)], name+'_reps_len', kwargs['initializer'])
22862275

2287-
input_nodes.append(output_shape_name)
2288-
tile_node = onnx.helper.make_node(
2289-
"Tile",
2290-
input_nodes,
2291-
[name],
2292-
name=name
2293-
)
2276+
nodes = [
2277+
make_node('Shape', [data], [name+'_data_shape']),
2278+
make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
2279+
make_node('Max', [name+'_data_dim', name+'_reps_len'], [name+'_max']),
2280+
make_node('Sub', [name+'_max', name+'_data_dim'], [name+'_data_diff']),
2281+
make_node('Concat', [name+'_data_diff', name+'_0'], [name+'_concat0_out'], axis=0),
2282+
make_node('Pad', [name+'_data_shape', name+'_concat0_out', name+'_1'], [name+'_data_shape_pad']),
2283+
make_node('Reshape', [data, name+'_data_shape_pad'], [name+'_data']),
2284+
make_node('Sub', [name+'_max', name+'_reps_len'], [name+'_reps_diff']),
2285+
make_node('Concat', [name+'_reps_diff', name+'_0'], [name+'_concat1_out'], axis=0),
2286+
make_node('Pad', [name+'_reps', name+'_concat1_out', name+'_1'], [name+'_reps_pad']),
2287+
make_node('Tile', [name+'_data', name+'_reps_pad'], [name], name=name),
2288+
]
22942289

2295-
return [tensor_node, tile_node]
2290+
return nodes
22962291

22972292

22982293
@mx_op.register("broadcast_to")

tests/python-pytest/onnx/test_operators.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,3 +1131,11 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i):
11311131
kwargs['is_ascend'] = is_ascend
11321132
M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs)
11331133
op_export_test('argsort', M, [A], tmp_path)
1134+
1135+
1136+
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1137+
@pytest.mark.parametrize('reps', [(2, 3), (2, ), (2, 3, 4)])
1138+
def test_onnx_export_tile(tmp_path, dtype, reps):
1139+
x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype)
1140+
M = def_model('tile', reps=reps)
1141+
op_export_test('tile', M, [x], tmp_path)

0 commit comments

Comments
 (0)