@@ -1047,11 +1047,12 @@ def convert_RNN(node, **kwargs):
     nodes = []
 
     mode = str(attrs.get('mode'))
+    create_tensor([0], name + '_0', kwargs['initializer'])
     create_tensor([1], name + '_1', kwargs['initializer'])
+
     if mode == 'lstm':
         initial_c = input_nodes[3]
         if num_layers == 2:
-            create_tensor([0], name + '_0', kwargs['initializer'])
             create_tensor([8 * state_size], name + '_8*state_size', kwargs['initializer'])
             create_tensor([4 * state_size * state_size], name + '_4*state_size^2', kwargs['initializer'])
             create_tensor([1, 4 * state_size, state_size], name + '_WR_shape', kwargs['initializer'])
@@ -1123,7 +1124,6 @@ def convert_RNN(node, **kwargs):
                 make_node('Concat', [name + '_lstm0_c', name + '_lstm1_c'], [name + '2'], axis=0),
             ]
         elif num_layers == 1:
-            create_tensor([0], name + '_0', kwargs['initializer'])
             create_tensor([4 * state_size], name + '_4*state_size', kwargs['initializer'])
             create_tensor([8 * state_size], name + '_8*state_size', kwargs['initializer'])
             create_tensor([4 * state_size * state_size], name + '_4*state_size^2', kwargs['initializer'])
@@ -1167,7 +1167,6 @@ def convert_RNN(node, **kwargs):
 
     elif mode == 'gru':
         if num_layers == 2:
-            create_tensor([0], name + '_0', kwargs['initializer'])
             create_tensor([6 * state_size], name + '_6*state_size', kwargs['initializer'])
             create_tensor([3 * state_size * state_size], name + '_3*state_size^2', kwargs['initializer'])
             create_tensor([1, 3 * state_size, state_size], name + '_WR_shape', kwargs['initializer'])
@@ -1238,7 +1237,6 @@ def convert_RNN(node, **kwargs):
             ]
 
         elif num_layers == 1:
-            create_tensor([0], name + '_0', kwargs['initializer'])
             create_tensor([3 * state_size], name + '_3*state_size', kwargs['initializer'])
             create_tensor([6 * state_size], name + '_6*state_size', kwargs['initializer'])
             create_tensor([3 * state_size * state_size], name + '_3*state_size^2', kwargs['initializer'])
@@ -1272,14 +1270,106 @@ def convert_RNN(node, **kwargs):
                 # get seq_len
                 make_node('Tile', [name + '_seq_length', name + '_batch_size'], [name + '_seq_len_']),
                 make_node("Cast", [name + '_seq_len_'], [name + "_seq_len"], to=int(TensorProto.INT32)),
-                # compute LSTM
+                # compute GRU
                 make_node('GRU', [data, name + '_W', name + '_R', name + '_B', name + '_seq_len', initial_h],
                           [name + '0_', name + '1'], hidden_size=state_size, linear_before_reset=1),
                 make_node('Squeeze', [name + '0_', name + '_1'], [name]),
             ]
         else:
             raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
 
+    elif mode in ['rnn_tanh', 'rnn_relu']:
+        activations = ['Tanh']
+        if mode == 'rnn_relu':
+            activations = ['Relu']
+        if num_layers == 2:
+            create_tensor([2 * state_size], name + '_2*state_size', kwargs['initializer'])
+            create_tensor([state_size * state_size], name + '_state_size^2', kwargs['initializer'])
+            create_tensor([1, state_size, state_size], name + '_WR_shape', kwargs['initializer'])
+            create_tensor([1, 2 * state_size], name + '_B_shape', kwargs['initializer'])
+            create_tensor([4 * state_size * state_size], name + '_WR_offset', kwargs['initializer'])
+
+            nodes += [
+                make_node('Shape', [data], [name + '_data_shape']),
+                make_node('Split', [name + '_data_shape'], [name + '_seq_length', name + '_batch_size', name + '_input_size']),
+
+                # Layer 0
+                # get W
+                make_node('Slice', [param, name + '_0', name + '_state_size^2'], [name + '_W0_1d']),
+                make_node('Reshape', [name + '_W0_1d', name + '_WR_shape'], [name + '_W0']),
+                # get R
+                make_node('Add', [name + '_state_size^2', name + '_state_size^2'], [name + '_R0_offset']),
+                make_node('Slice', [param, name + '_state_size^2', name + '_R0_offset'], [name + '_R0_1d']),
+                make_node('Reshape', [name + '_R0_1d', name + '_WR_shape'], [name + '_R0']),
+                # get B
+                make_node('Add', [name + '_WR_offset', name + '_2*state_size'], [name + '_B0_offset']),
+                make_node('Slice', [param, name + '_WR_offset', name + '_B0_offset'], [name + '_B0_1d']),
+                make_node('Reshape', [name + '_B0_1d', name + '_B_shape'], [name + '_B0']),
+                # get initial states
+                make_node('Split', [initial_h], [name + '_initial_h0', name + '_initial_h1'], axis=0),
+                # get seq_len
+                make_node('Tile', [name + '_seq_length', name + '_batch_size'], [name + '_seq_len_']),
+                make_node("Cast", [name + '_seq_len_'], [name + "_seq_len"], to=int(TensorProto.INT32)),
+                # Layer 0 RNN
+                make_node('RNN', [data, name + '_W0', name + '_R0', name + '_B0', name + '_seq_len',
+                                  name + '_initial_h0'], [name + '_rnn0_out_', name + '_rnn0_h'],
+                          hidden_size=state_size, activations=activations),
+                make_node('Squeeze', [name + '_rnn0_out_', name + '_1'], [name + '_rnn0_out']),
+
+                # Layer 1
+                # get W
+                make_node('Add', [name + '_R0_offset', name + '_state_size^2'], [name + '_W1_offset']),
+                make_node('Slice', [param, name + '_R0_offset', name + '_W1_offset'], [name + '_W1_1d']),
+                make_node('Reshape', [name + '_W1_1d', name + '_WR_shape'], [name + '_W1']),
+                # get R
+                make_node('Slice', [param, name + '_W1_offset', name + '_WR_offset'], [name + '_R1_1d']),
+                make_node('Reshape', [name + '_R1_1d', name + '_WR_shape'], [name + '_R1']),
+                # get B
+                make_node('Add', [name + '_B0_offset', name + '_2*state_size'], [name + '_B1_offset']),
+                make_node('Slice', [param, name + '_B0_offset', name + '_B1_offset'], [name + '_B1_1d']),
+                make_node('Reshape', [name + '_B1_1d', name + '_B_shape'], [name + '_B1']),
+                # Layer 1 RNN
+                make_node('RNN', [name + '_rnn0_out', name + '_W1', name + '_R1', name + '_B1', name + '_seq_len',
+                                  name + '_initial_h1'], [name + '_rnn1_out_', name + '_rnn1_h'],
+                          hidden_size=state_size, activations=activations),
+                make_node('Squeeze', [name + '_rnn1_out_', name + '_1'], [name]),
+                make_node('Concat', [name + '_rnn0_h', name + '_rnn1_h'], [name + '1'], axis=0)
+            ]
+
+        elif num_layers == 1:
+            create_tensor([state_size], name + '_state_size', kwargs['initializer'])
+            create_tensor([2 * state_size], name + '_2*state_size', kwargs['initializer'])
+            create_tensor([state_size * state_size], name + '_state_size^2', kwargs['initializer'])
+            create_tensor([1, state_size, state_size], name + '_R_shape', kwargs['initializer'])
+            create_tensor([1, 2 * state_size], name + '_B_shape', kwargs['initializer'])
+
+            nodes += [
+                make_node('Shape', [data], [name + '_data_shape']),
+                make_node('Split', [name + '_data_shape'], [name + '_seq_length', name + '_batch_size',
+                                                            name + '_input_size'], name='split0'),
+                # get W
+                make_node('Mul', [name + '_state_size', name + '_input_size'], [name + '_mul0']),
+                make_node('Slice', [param, name + '_0', name + '_mul0'], [name + '_W_1d']),
+                make_node('Concat', [name + '_1', name + '_state_size', name + '_input_size'], [name + '_W_shape'], axis=0),
+                make_node('Reshape', [name + '_W_1d', name + '_W_shape'], [name + '_W']),
+                # get R
+                make_node('Add', [name + '_mul0', name + '_state_size^2'], [name + '_add0']),
+                make_node('Slice', [param, name + '_mul0', name + '_add0'], [name + '_R_1d']),
+                make_node('Reshape', [name + '_R_1d', name + '_R_shape'], [name + '_R']),
+                # get B
+                make_node('Add', [name + '_add0', name + '_2*state_size'], [name + '_add1']),
+                make_node('Slice', [param, name + '_add0', name + '_add1'], [name + '_B_1d']),
+                make_node('Reshape', [name + '_B_1d', name + '_B_shape'], [name + '_B']),
+                # get seq_len
+                make_node('Tile', [name + '_seq_length', name + '_batch_size'], [name + '_seq_len_']),
+                make_node("Cast", [name + '_seq_len_'], [name + "_seq_len"], to=int(TensorProto.INT32)),
+                # compute RNN
+                make_node('RNN', [data, name + '_W', name + '_R', name + '_B', name + '_seq_len', initial_h],
+                          [name + '0_', name + '1'], hidden_size=state_size, activations=activations),
+                make_node('Squeeze', [name + '0_', name + '_1'], [name]),
+            ]
+        else:
+            raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
     else:
         raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
     return nodes
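
Note on the parameter layout the new branch relies on: MXNet's fused RNN op keeps all weights and biases in a single flattened param vector, and the Slice/Reshape chain above carves that vector into the W, R, and B inputs of the ONNX RNN operator. The NumPy sketch below mirrors the offsets computed in the num_layers == 1 branch (W: state_size * input_size values, then R: state_size * state_size, then the two biases: 2 * state_size). The helper name and signature are illustrative only, not part of the patch.

import numpy as np

def unfuse_rnn_params(param, input_size, state_size):
    # Hypothetical helper: splits a fused 1-D parameter vector the same way the
    # exported Slice/Reshape nodes do for a single-layer, unidirectional RNN.
    w_end = state_size * input_size            # matches name + '_mul0'
    r_end = w_end + state_size * state_size    # matches name + '_add0'
    b_end = r_end + 2 * state_size             # matches name + '_add1'
    W = param[:w_end].reshape(1, state_size, input_size)        # ONNX RNN input W
    R = param[w_end:r_end].reshape(1, state_size, state_size)   # ONNX RNN input R
    B = param[r_end:b_end].reshape(1, 2 * state_size)           # Wb and Rb concatenated
    return W, R, B

# input_size=10, state_size=20 -> 10*20 + 20*20 + 2*20 = 640 fused parameters
W, R, B = unfuse_rnn_params(np.zeros(640, dtype=np.float32), input_size=10, state_size=20)
assert W.shape == (1, 20, 10) and R.shape == (1, 20, 20) and B.shape == (1, 40)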