Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2127c3e

Browse files
waytrue17 (Wei Chu)
and authored
[v1.x] ONNX export support for RNN and sum_axis (#20226)
* export support RNN * add sum_axis * fix sanity * fix sanity * fix sanity * change regiester sum_axis Co-authored-by: Wei Chu <weichu@amazon.com>
1 parent 4056c07 commit 2127c3e

File tree

3 files changed

+212
-23
lines changed

3 files changed

+212
-23
lines changed

python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,9 @@ def convert_square(node, **kwargs):
21372137
)
21382138
return [tensor_node, node]
21392139

2140+
# sum_axis is equivalent to sum in MXNet
21402141
@mx_op.register("sum")
2142+
@mx_op.register("sum_axis")
21412143
def convert_sum(node, **kwargs):
21422144
"""Map MXNet's sum operator attributes to onnx's ReduceSum operator
21432145
and return the created node.
@@ -4476,12 +4478,12 @@ def convert_RNN(node, **kwargs):
44764478
initial_h = input_nodes[2]
44774479

44784480
nodes = []
4481+
create_tensor([0], name+'_0', kwargs['initializer'])
44794482

44804483
mode = str(attrs.get('mode'))
44814484
if mode == 'lstm':
44824485
initial_c = input_nodes[3]
44834486
if num_layers == 2:
4484-
create_tensor([0], name+'_0', kwargs['initializer'])
44854487
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
44864488
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
44874489
create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -4553,7 +4555,6 @@ def convert_RNN(node, **kwargs):
45534555
make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
45544556
]
45554557
elif num_layers == 1:
4556-
create_tensor([0], name+'_0', kwargs['initializer'])
45574558
create_tensor([1], name+'_1', kwargs['initializer'])
45584559
create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
45594560
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
@@ -4598,7 +4599,6 @@ def convert_RNN(node, **kwargs):
45984599

45994600
elif mode == 'gru':
46004601
if num_layers == 2:
4601-
create_tensor([0], name+'_0', kwargs['initializer'])
46024602
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
46034603
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
46044604
create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -4669,7 +4669,7 @@ def convert_RNN(node, **kwargs):
46694669
]
46704670

46714671
elif num_layers == 1:
4672-
create_tensor([0], name+'_0', kwargs['initializer'])
4672+
46734673
create_tensor([1], name+'_1', kwargs['initializer'])
46744674
create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
46754675
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
@@ -4712,6 +4712,100 @@ def convert_RNN(node, **kwargs):
47124712
else:
47134713
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
47144714

4715+
elif mode in ['rnn_tanh', 'rnn_relu']:
4716+
activations = ['Tanh']
4717+
if mode == 'rnn_relu':
4718+
activations = ['Relu']
4719+
if num_layers == 2:
4720+
4721+
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
4722+
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
4723+
create_tensor([1, state_size, state_size], name+'_WR_shape', kwargs['initializer'])
4724+
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
4725+
create_tensor([4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
4726+
4727+
nodes += [
4728+
make_node('Shape', [data], [name+'_data_shape']),
4729+
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
4730+
4731+
# Layer 0
4732+
# get W
4733+
make_node('Slice', [param, name+'_0', name+'_state_size^2'], [name+'_W0_1d']),
4734+
make_node('Reshape', [name+'_W0_1d', name+'_WR_shape'], [name+'_W0']),
4735+
# get R
4736+
make_node('Add', [name+'_state_size^2', name+'_state_size^2'], [name+'_R0_offset']),
4737+
make_node('Slice', [param, name+'_state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
4738+
make_node('Reshape', [name+'_R0_1d', name+'_WR_shape'], [name+'_R0']),
4739+
# get B
4740+
make_node('Add', [name+'_WR_offset', name+'_2*state_size'], [name+'_B0_offset']),
4741+
make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
4742+
make_node('Reshape', [name+'_B0_1d', name+'_B_shape'], [name+'_B0']),
4743+
# get initial states
4744+
make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
4745+
# get seq_len
4746+
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
4747+
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
4748+
# Layer 0 RNN
4749+
make_node('RNN', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len', name+'_initial_h0'],
4750+
[name+'_rnn0_out_', name+'_rnn0_h'], hidden_size=state_size, activations=activations),
4751+
make_node('Squeeze', [name+'_rnn0_out_'], [name+'_rnn0_out'], axes=[1]),
4752+
4753+
# Layer 1
4754+
# get W
4755+
make_node('Add', [name+'_R0_offset', name+'_state_size^2'], [name+'_W1_offset']),
4756+
make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
4757+
make_node('Reshape', [name+'_W1_1d', name+'_WR_shape'], [name+'_W1']),
4758+
# get R
4759+
make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
4760+
make_node('Reshape', [name+'_R1_1d', name+'_WR_shape'], [name+'_R1']),
4761+
# get B
4762+
make_node('Add', [name+'_B0_offset', name+'_2*state_size'], [name+'_B1_offset']),
4763+
make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
4764+
make_node('Reshape', [name+'_B1_1d', name+'_B_shape'], [name+'_B1']),
4765+
# Layer 1 RNN
4766+
make_node('RNN', [name+'_rnn0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
4767+
name+'_initial_h1'], [name+'_rnn1_out_', name+'_rnn1_h'],
4768+
hidden_size=state_size, activations=activations),
4769+
make_node('Squeeze', [name+'_rnn1_out_'], [name], axes=[1]),
4770+
make_node('Concat', [name+'_rnn0_h', name+'_rnn1_h'], [name+'1'], axis=0)
4771+
]
4772+
4773+
elif num_layers == 1:
4774+
4775+
create_tensor([1], name+'_1', kwargs['initializer'])
4776+
create_tensor([state_size], name+'_state_size', kwargs['initializer'])
4777+
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
4778+
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
4779+
create_tensor([1, state_size, state_size], name+'_R_shape', kwargs['initializer'])
4780+
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
4781+
4782+
nodes += [
4783+
make_node('Shape', [data], [name+'_data_shape']),
4784+
make_node('Split', [name+'_data_shape'],
4785+
[name+'_seq_length', name+'_batch_size', name+'_input_size'], name='split0'),
4786+
# get W
4787+
make_node('Mul', [name+'_state_size', name+'_input_size'], [name+'_mul0']),
4788+
make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
4789+
make_node('Concat', [name+'_1', name+'_state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
4790+
make_node('Reshape', [name+'_W_1d', name+'_W_shape'], [name+'_W']),
4791+
# get R
4792+
make_node('Add', [name+'_mul0', name+'_state_size^2'], [name+'_add0']),
4793+
make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
4794+
make_node('Reshape', [name+'_R_1d', name+'_R_shape'], [name+'_R']),
4795+
# get B
4796+
make_node('Add', [name+'_add0', name+'_2*state_size'], [name+'_add1']),
4797+
make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
4798+
make_node('Reshape', [name+'_B_1d', name+'_B_shape'], [name+'_B']),
4799+
# get seq_len
4800+
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
4801+
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
4802+
# compute RNN
4803+
make_node('RNN', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
4804+
[name+'0_', name+'1'], hidden_size=state_size, activations=activations),
4805+
make_node('Squeeze', [name+'0_'], [name], axes=[1]),
4806+
]
4807+
else:
4808+
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
47154809
else:
47164810
raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
47174811
return nodes

python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,11 +1047,12 @@ def convert_RNN(node, **kwargs):
10471047
nodes = []
10481048

10491049
mode = str(attrs.get('mode'))
1050+
create_tensor([0], name+'_0', kwargs['initializer'])
10501051
create_tensor([1], name+'_1', kwargs['initializer'])
1052+
10511053
if mode == 'lstm':
10521054
initial_c = input_nodes[3]
10531055
if num_layers == 2:
1054-
create_tensor([0], name+'_0', kwargs['initializer'])
10551056
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
10561057
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
10571058
create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -1123,7 +1124,6 @@ def convert_RNN(node, **kwargs):
11231124
make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
11241125
]
11251126
elif num_layers == 1:
1126-
create_tensor([0], name+'_0', kwargs['initializer'])
11271127
create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
11281128
create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
11291129
create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
@@ -1167,7 +1167,6 @@ def convert_RNN(node, **kwargs):
11671167

11681168
elif mode == 'gru':
11691169
if num_layers == 2:
1170-
create_tensor([0], name+'_0', kwargs['initializer'])
11711170
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
11721171
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
11731172
create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -1238,7 +1237,6 @@ def convert_RNN(node, **kwargs):
12381237
]
12391238

12401239
elif num_layers == 1:
1241-
create_tensor([0], name+'_0', kwargs['initializer'])
12421240
create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
12431241
create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
12441242
create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
@@ -1272,14 +1270,106 @@ def convert_RNN(node, **kwargs):
12721270
# get seq_len
12731271
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
12741272
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
1275-
# compute LSTM
1273+
# compute GRU
12761274
make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
12771275
[name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1),
12781276
make_node('Squeeze', [name+'0_', name+'_1'], [name]),
12791277
]
12801278
else:
12811279
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
12821280

1281+
elif mode in ['rnn_tanh', 'rnn_relu']:
1282+
activations = ['Tanh']
1283+
if mode == 'rnn_relu':
1284+
activations = ['Relu']
1285+
if num_layers == 2:
1286+
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
1287+
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
1288+
create_tensor([1, state_size, state_size], name+'_WR_shape', kwargs['initializer'])
1289+
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
1290+
create_tensor([4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
1291+
1292+
nodes += [
1293+
make_node('Shape', [data], [name+'_data_shape']),
1294+
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
1295+
1296+
# Layer 0
1297+
# get W
1298+
make_node('Slice', [param, name+'_0', name+'_state_size^2'], [name+'_W0_1d']),
1299+
make_node('Reshape', [name+'_W0_1d', name+'_WR_shape'], [name+'_W0']),
1300+
# get R
1301+
make_node('Add', [name+'_state_size^2', name+'_state_size^2'], [name+'_R0_offset']),
1302+
make_node('Slice', [param, name+'_state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
1303+
make_node('Reshape', [name+'_R0_1d', name+'_WR_shape'], [name+'_R0']),
1304+
# get B
1305+
make_node('Add', [name+'_WR_offset', name+'_2*state_size'], [name+'_B0_offset']),
1306+
make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
1307+
make_node('Reshape', [name+'_B0_1d', name+'_B_shape'], [name+'_B0']),
1308+
# get initial states
1309+
make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
1310+
# get seq_len
1311+
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
1312+
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
1313+
# Layer 0 RNN
1314+
make_node('RNN', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
1315+
name+'_initial_h0'], [name+'_rnn0_out_', name+'_rnn0_h'],
1316+
hidden_size=state_size, activations=activations),
1317+
make_node('Squeeze', [name+'_rnn0_out_', name+'_1'], [name+'_rnn0_out']),
1318+
1319+
# Layer 1
1320+
# get W
1321+
make_node('Add', [name+'_R0_offset', name+'_state_size^2'], [name+'_W1_offset']),
1322+
make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
1323+
make_node('Reshape', [name+'_W1_1d', name+'_WR_shape'], [name+'_W1']),
1324+
# get R
1325+
make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
1326+
make_node('Reshape', [name+'_R1_1d', name+'_WR_shape'], [name+'_R1']),
1327+
# get B
1328+
make_node('Add', [name+'_B0_offset', name+'_2*state_size'], [name+'_B1_offset']),
1329+
make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
1330+
make_node('Reshape', [name+'_B1_1d', name+'_B_shape'], [name+'_B1']),
1331+
# Layer 1 RNN
1332+
make_node('RNN', [name+'_rnn0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
1333+
name+'_initial_h1'], [name+'_rnn1_out_', name+'_rnn1_h'],
1334+
hidden_size=state_size, activations=activations),
1335+
make_node('Squeeze', [name+'_rnn1_out_', name+'_1'], [name]),
1336+
make_node('Concat', [name+'_rnn0_h', name+'_rnn1_h'], [name+'1'], axis=0)
1337+
]
1338+
1339+
elif num_layers == 1:
1340+
create_tensor([state_size], name+'_state_size', kwargs['initializer'])
1341+
create_tensor([2*state_size], name+'_2*state_size', kwargs['initializer'])
1342+
create_tensor([state_size*state_size], name+'_state_size^2', kwargs['initializer'])
1343+
create_tensor([1, state_size, state_size], name+'_R_shape', kwargs['initializer'])
1344+
create_tensor([1, 2*state_size], name+'_B_shape', kwargs['initializer'])
1345+
1346+
nodes += [
1347+
make_node('Shape', [data], [name+'_data_shape']),
1348+
make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size',
1349+
name+'_input_size'], name='split0'),
1350+
# get W
1351+
make_node('Mul', [name+'_state_size', name+'_input_size'], [name+'_mul0']),
1352+
make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
1353+
make_node('Concat', [name+'_1', name+'_state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
1354+
make_node('Reshape', [name+'_W_1d', name+'_W_shape'], [name+'_W']),
1355+
# get R
1356+
make_node('Add', [name+'_mul0', name+'_state_size^2'], [name+'_add0']),
1357+
make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
1358+
make_node('Reshape', [name+'_R_1d', name+'_R_shape'], [name+'_R']),
1359+
# get B
1360+
make_node('Add', [name+'_add0', name+'_2*state_size'], [name+'_add1']),
1361+
make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
1362+
make_node('Reshape', [name+'_B_1d', name+'_B_shape'], [name+'_B']),
1363+
# get seq_len
1364+
make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
1365+
make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
1366+
# compute RNN
1367+
make_node('RNN', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
1368+
[name+'0_', name+'1'], hidden_size=state_size, activations=activations),
1369+
make_node('Squeeze', [name+'0_', name+'_1'], [name]),
1370+
]
1371+
else:
1372+
raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
12831373
else:
12841374
raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
12851375
return nodes

0 comments on commit

Comments (0)