diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py
index b839af669fe6..6f8aab23cde1 100644
--- a/python/tvm/contrib/target/onnx.py
+++ b/python/tvm/contrib/target/onnx.py
@@ -655,7 +655,7 @@ def convert_attributes(cls, attrs):
 
 
 class Cast(OpConverter):
-    """ Operator converter for Cast."""
+    """Operator converter for Cast."""
 
     @classmethod
     def convert_attributes(cls, attrs):
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 9c53b59f9998..7f67ed404de9 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -624,3 +624,128 @@ def to_int_list(np_array):
     cause problems in relay/TOPI.
     """
    return [int(x) for x in np_array]
+
+
+def unbind(data, axis=0):
+    """
+    Unbind was adopted from the PyTorch frontend. The operation removes a tensor
+    dimension and returns a tuple of all slices along the given axis, with that
+    axis removed.
+    TODO (vvchernov): a dedicated Relay op is needed to avoid the overhead of the
+    squeeze operations.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        Input tensor
+    axis : int
+        Axis along which the tensor is split.
+    Returns
+    -------
+    result : List[relay.Expr]
+        The sequence of computed tensors
+    """
+    shape = infer_shape(data)
+    if axis >= len(shape):
+        msg = "Please check input dim, it shouldn't be greater than or equal to rank."
+        raise AttributeError(msg)
+
+    selections = shape[axis]
+    res_split = _op.split(data, selections, axis)
+    ret = []
+    for i in range(selections):
+        ret.append(_op.squeeze(res_split[i], axis=[axis]))
+    return _expr.TupleWrapper(_expr.Tuple(ret), selections)
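
# Usage sketch for the new `unbind` helper (illustrative only, not part of the
# patch; assumes a TVM build that already includes this change):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import unbind

x = relay.var("x", shape=(3, 4, 5), dtype="float32")
slices = unbind(x, axis=0)  # TupleWrapper over 3 expressions, each of shape (4, 5)
mod = tvm.IRModule.from_expr(relay.Function([x], slices.astuple()))
outs = relay.create_executor("graph", mod=mod).evaluate()(np.ones((3, 4, 5), "float32"))
assert len(outs) == 3 and outs[0].shape == (4, 5)
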
+
+
+def lstm_cell(
+    input_seqs,
+    hidden_state,
+    cell_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    proj=None,
+    p_i=None,
+    p_f=None,
+    p_o=None,
+    f_act=_op.sigmoid,
+    g_act=_op.tanh,
+    h_act=_op.tanh,
+    backwards=False,
+):
+    """
+    Common implementation of the LSTM cell for all TVM frontends
+    TODO (vvchernov): currently it is used by onnx and pytorch. Extend for other frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensors should be 2d until issue #8412 is resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch, hidden_size)
+    cell_state : relay.Expr
+        Cell state. shape = (batch, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. wi shape = (4 * hidden_size, feature_size)
+        wh shape = (4 * hidden_size, hidden_size or proj_size)
+        NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates.
+        The order is important for correct LSTM calculation!
+    b_inp, b_hid : relay.Expr
+        bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size)
+    proj : relay.Expr
+        projection matrix. shape = (proj_size, hidden_size)
+    p_i, p_f, p_o : relay.Expr
+        peephole LSTM matrices. shape = (batch, hidden_size)
+    f_act, g_act, h_act : relay.op
+        activation functions
+    backwards : bool
+        Flag for the reverse pass of LSTM
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr, relay.Expr
+        The sequence of computed results, and the final hidden and cell states
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        # x_t shape = (batch, feature_size), step shape = (batch, feature_size + hidden_size)
+        step = _op.concatenate([x_t, hidden_state], axis=1)
+        cat_w = _op.concatenate([w_inp, w_hid], axis=1)
+        # nn.dense(step, cat_w) is used instead of
+        # nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
+        # gates shape = (batch, 4 * hidden_size)
+        gates = _op.nn.dense(step, cat_w)
+        # Add biases
+        if b_inp is not None:
+            gates += b_inp
+        if b_hid is not None:
+            gates += b_hid
+        # any gate shape = (batch, hidden_size)
+        inp_gate, fgt_gate, cell_gate, otp_gate = _op.split(gates, 4, axis=-1)
+
+        if p_i is not None and p_f is not None:
+            inp_gate = f_act(inp_gate + p_i * cell_state)
+            fgt_gate = f_act(fgt_gate + p_f * cell_state)
+        else:
+            inp_gate = f_act(inp_gate)
+            fgt_gate = f_act(fgt_gate)
+
+        cell_gate = g_act(cell_gate)
+        cell_state = fgt_gate * cell_state + inp_gate * cell_gate
+        if p_o is not None:
+            otp_gate = f_act(otp_gate + p_o * cell_state)
+        else:
+            otp_gate = f_act(otp_gate)
+
+        hidden_state = otp_gate * h_act(cell_state)
+
+        if proj is not None:
+            hidden_state = _op.nn.dense(hidden_state, proj)
+
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+
+    return outputs_list, hidden_state, cell_state
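
# Usage sketch for the shared `lstm_cell` helper (illustrative only, not part of
# the patch). Variable names are hypothetical; the weights are assumed to already
# be packed in the i-f-c-o gate order that the docstring above requires:

from tvm import relay
from tvm.relay.frontend.common import lstm_cell, unbind

seq_len, batch, feature_size, hidden_size = 5, 2, 8, 16
X = relay.var("X", shape=(seq_len, batch, feature_size), dtype="float32")
H0 = relay.var("H0", shape=(batch, hidden_size), dtype="float32")
C0 = relay.var("C0", shape=(batch, hidden_size), dtype="float32")
Wi = relay.var("Wi", shape=(4 * hidden_size, feature_size), dtype="float32")
Wh = relay.var("Wh", shape=(4 * hidden_size, hidden_size), dtype="float32")

# unbind turns (seq_len, batch, feature_size) into seq_len slices of (batch, feature_size)
outputs, H_t, C_t = lstm_cell(
    input_seqs=unbind(X, axis=0),
    hidden_state=H0,
    cell_state=C0,
    w_inp=Wi,
    w_hid=Wh,
)
# outputs: list of seq_len (batch, hidden_size) tensors; H_t, C_t: final states
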
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 42bde838859a..adbbaf9ce885 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -46,6 +46,8 @@
     infer_type,
     infer_value,
     new_var,
+    unbind,
+    lstm_cell,
 )
 
 __all__ = ["from_onnx"]
@@ -2142,58 +2144,44 @@ class LSTM(RNN):
     """Operator converter for LSTM"""
 
     @classmethod
-    def generate_lstm(
-        cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False
+    def bidir_lstm_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
     ):
-        """Create an unrolled lstm loop.
-
-        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
         """
-        h_list = []
-        seq_length = len(X_steps)
-        for i in range(seq_length):
-            step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
-            step = _op.squeeze(step, axis=[0])
-            gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
-            if B is not None:
-                WB, RB = _op.split(B, 2)
-                gates += WB + RB
-            i, o, f, c = _op.split(gates, 4, axis=-1)
-
-            if p_i != 0:
-                i = f_act(i + p_i * C_t)
-            else:
-                i = f_act(i)
-
-            if p_f != 0:
-                f = f_act(f + p_f * C_t)
-            else:
-                f = f_act(f)
-
-            c = g_act(c)
-            C = f * C_t + i * c
-            if p_o != 0:
-                o = f_act(o + p_o * C)
-            else:
-                o = f_act(o)
-
-            H = o * h_act(C)
-
-            H_t = H
-            C_t = C
-            h_list.append(_op.expand_dims(H, axis=0))
+        Bidirectional LSTM cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t, fw_C_t = lstm_cell(
+            input_seqs,
+            **weight_dicts[0],
+            f_act=acts[0],
+            g_act=acts[1],
+            h_act=acts[2],
+        )
 
-        if backwards:
-            # Canonical view is hidden states from the first token not last
-            h_list = h_list[::-1]
+        reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
+            input_seqs,
+            **weight_dicts[1],
+            f_act=acts[3],
+            g_act=acts[4],
+            h_act=acts[5],
+            backwards=True,
+        )
 
-        # Concatenate outputs and add back in direction axis.
-        concatenated = _op.concatenate(h_list, 0)
-        output = _op.expand_dims(concatenated, axis=1)
-        H_t = _op.expand_dims(H_t, axis=0)
-        C_t = _op.expand_dims(C_t, axis=0)
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
 
-        return output, H_t, C_t
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+            _op.stack([fw_C_t, rev_C_t], axis=0),
+        )
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -2224,12 +2212,6 @@ def _impl_v7(cls, inputs, attr, params):
             Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
         if Cp_0 is None:
             Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Bp is None:
-            Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype)
-        if Pp is not None:
-            p_i, p_o, p_f = _op.split(Pp, 3, axis=1)
-        else:
-            p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype)
 
         if "activations" in attr:
             activations = attr["activations"]
@@ -2260,53 +2242,67 @@ def _impl_v7(cls, inputs, attr, params):
         else:
             acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions
 
-        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
-        result_output = []
-        result_H = []
-        result_C = []
+        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
+        X_steps = unbind(X, axis=0)
 
         H_ts = _op.split(Hp_0, num_directions)
         C_ts = _op.split(Cp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
-        Bs = _op.split(Bp, num_directions)
-        p_is = _op.split(p_i, num_directions)
-        p_fs = _op.split(p_f, num_directions)
-        p_os = _op.split(p_o, num_directions)
-        for i in range(num_directions):
-            H_t = _op.squeeze(H_ts[i], axis=[0])
-            C_t = _op.squeeze(C_ts[i], axis=[0])
-            W = _op.squeeze(Ws[i], axis=[0])
-            R = _op.squeeze(Rs[i], axis=[0])
-            B = _op.squeeze(Bs[i], axis=[0])
-            p_i = _op.squeeze(p_is[i], axis=[0])
-            p_f = _op.squeeze(p_fs[i], axis=[0])
-            p_o = _op.squeeze(p_os[i], axis=[0])
-            f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3]
-            output, H, C = LSTM.generate_lstm(
-                X_steps=X_steps,
-                H_t=H_t,
-                C_t=C_t,
-                W=W,
-                R=R,
-                B=B,
-                p_i=p_i,
-                p_f=p_f,
-                p_o=p_o,
-                f_act=f_act,
-                g_act=g_act,
-                h_act=h_act,
-                backwards=i == 1,
-            )
+        if Bp is not None:
+            Bs = _op.split(Bp, num_directions)
+        if Pp is not None:
+            p_i, p_o, p_f = _op.split(Pp, 3, axis=1)
 
-            result_output.append(output)
-            result_H.append(H)
-            result_C.append(C)
+            p_is = _op.split(p_i, num_directions)
+            p_fs = _op.split(p_f, num_directions)
+            p_os = _op.split(p_o, num_directions)
 
-        output = _op.concatenate(result_output, axis=1)
-        H = _op.concatenate(result_H, axis=0)
-        C = _op.concatenate(result_C, axis=0)
+        weights_dicts = []
+        for i in range(num_directions):
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["cell_state"] = _op.squeeze(C_ts[i], axis=[0])
+
+            # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o
+            mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4)
+            weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+            mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4)
+            weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+            if Bp is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4)
+                weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+                mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4)
+                weights_dict["b_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+            if Pp is not None:
+                weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0])
+                weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0])
+                weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0])
+            weights_dicts.append(weights_dict)
+
+        if num_directions == 2:
+            output, H, C = LSTM.bidir_lstm_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H, C = lstm_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                f_act=acts[0],
+                g_act=acts[1],
+                h_act=acts[2],
+            )
+
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)
+            C = _op.expand_dims(C, axis=0)
 
         return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)
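
# Illustration of the gate reordering performed above (illustrative only, not part
# of the patch): ONNX packs the stacked gate weights as (i | o | f | c) along axis 0,
# while lstm_cell expects (i | f | c | o), so each matrix is split into four blocks
# and re-concatenated. The same trick, sketched in numpy:

import numpy as np

hidden_size, feature_size = 3, 2
W = np.arange(4 * hidden_size * feature_size, dtype="float32").reshape(
    4 * hidden_size, feature_size
)
mati, mato, matf, matc = np.split(W, 4, axis=0)  # ONNX order: i, o, f, c
W_ifco = np.concatenate([mati, matf, matc, mato], axis=0)  # helper order: i, f, c, o
assert W_ifco.shape == W.shape
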
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 33cb83b883bc..321bdc6d62e5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -39,7 +39,7 @@
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op
+from .common import AttrCvt, get_relay_op, unbind, lstm_cell
 from .common import infer_value as _infer_value
 from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
@@ -2094,21 +2094,8 @@ def deform_conv2d(self, inputs, input_types):
 
     def unbind(self, inputs, input_types):
         data = inputs[0]
-        dim = int(inputs[1])
-        ishapes = self.infer_shape(data)
-        if dim >= len(ishapes):
-            msg = "Please check input dim, it shouldn't be greater than or equal to rank."
-            raise AttributeError(msg)
-
-        selections = ishapes[dim]
-        res_split = _op.split(data, selections, dim)
-        # squeeze each split piece to get same shape as aten::unbind
-        # TODO (yongwww): add new op to avoid the squeeze overhead
-        ret = []
-        for i in range(selections):
-            ret.append(_op.transform.squeeze(res_split[i], axis=[dim]))
-        ret = _expr.TupleWrapper(_expr.Tuple(ret), selections)
-        return ret
+        axis = int(inputs[1])
+        return unbind(data, axis)
 
     def shape_as_tensor(self, inputs, input_types):
         is_symbolic_shape = False
@@ -2135,7 +2122,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False):
         data = inputs[0]
         ret = _op.transform.argwhere(data)
         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            return self.unbind([ret, 1], None)
+            return unbind(ret, 1)
         return ret
 
     def nonzero_numpy(self, inputs, input_types):
@@ -2330,102 +2317,65 @@ def flip(self, inputs, input_types):
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
-    def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
-        if has_proj:
-            assert len(weights) == 5
-        else:
-            assert len(weights) == 4
-        outputs_list = []
-        # Default activations types
-        f_act = _op.sigmoid
-        g_act = _op.tanh
-        h_act = _op.tanh
-
-        # Input hiddens
-        H_t = hidden[0]  # (batch, hidden_size)
-        C_t = hidden[1]  # (batch, hidden_size)
-        for x_t in input_seqs:
-            # x_t shape = (batch, feature size)
-            # gates shape = (batch, 4 * hidden_size)
-            gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1])
-            # Add biases
-            if weights[2] is not None:
-                gates += weights[2]
-            if weights[3] is not None:
-                gates += weights[3]
-            i, f, c, o = _op.split(gates, 4, axis=-1)  # (batch, hidden_size)
-
-            i = f_act(i)
-            f = f_act(f)
-            c = g_act(c)
-            o = f_act(o)
-
-            C = f * C_t + i * c
-            H = o * h_act(C)
-
-            if has_proj:
-                H = _op.nn.dense(H, weights[4])
-
-            H_t = H
-            C_t = C
-            outputs_list.append(H)  # [seq_num, (batch, hidden_size)]
-        hidden_outputs = (H_t, C_t)
-
-        return (outputs_list, hidden_outputs)
+    def bidir_lstm_cell(
+        self,
+        input_seqs,
+        weights_dicts,
+    ):
+        """
+        Bidirectional LSTM cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t, fw_C_t = lstm_cell(
+            input_seqs,
+            **weights_dicts[0],
+        )
 
-    def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False):
-        fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)
+        reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
+            input_seqs,
+            **weights_dicts[1],
+            backwards=True,
+        )
 
-        rev_input_seq = []
-        seq_len = len(input_seq)
+        final_outputs = []
         for i in range(seq_len):
-            rev_input_seq.append(input_seq[seq_len - 1 - i])  # [seq_num, (batch, hidden_size)]
-        rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj)
-
-        final_outputs = []  # [seq_num, (batch, 2 * hidden_size)]
-        for j in range(seq_len):
             final_outputs.append(
-                _op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1)
+                _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
             )
 
-        return final_outputs, (fw_outputs[1], rev_outputs[1])
+        return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t)
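
# Shape check for the stitching above (illustrative only, not part of the patch):
# the backward pass already emits outputs in reversed time order, so step i of the
# forward pass pairs with step (seq_len - 1 - i) of the backward pass, giving
# per-step outputs of shape (batch, 2 * hidden_size). In numpy:

import numpy as np

seq_len, batch, hidden_size = 4, 2, 3
fwd = [np.full((batch, hidden_size), t, dtype="float32") for t in range(seq_len)]
bwd = [np.full((batch, hidden_size), -t, dtype="float32") for t in range(seq_len)]
final = [np.concatenate([fwd[i], bwd[seq_len - 1 - i]], axis=-1) for i in range(seq_len)]
assert all(step.shape == (batch, 2 * hidden_size) for step in final)
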
 
-    def lstm_layers(
-        self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
-    ):
-        hidden_layers_num = len(hiddens)
-        assert len(weights) == hidden_layers_num
+    def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0):
+        """
+        Method iterates over the layers of a stacked LSTM
+        """
+        layers_num = len(layer_weights_dicts)
         # split input sequence to samples set
-        input_seqs = self.unbind((input_data, 0), dtype)  # [seq_num, (batch, feature_size)]
+        input_seqs = unbind(input_data, 0)  # [seq_num, (batch, feature_size)]
         output_hiddens = []
-        for k in range(hidden_layers_num):
-            hiddens_input = hiddens[k]
-            weights_input = weights[k]
-
-            outputs = (
-                self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
-                if bidirectional
-                else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
-            )
-
-            output_hiddens.append(outputs[1])
+        for i in range(layers_num):
+            weights_dicts = layer_weights_dicts[i]
             # input_seqs shape = [seq_num, (batch, feature_size)] or
             # [seq_num, (batch, 2*feature_size)] for bidirectional
-            input_seqs = outputs[0]
+            if bidirectional:
+                input_seqs, H_t, C_t = self.bidir_lstm_cell(input_seqs, weights_dicts)
+            else:
+                input_seqs, H_t, C_t = lstm_cell(input_seqs, **weights_dicts[0])
+
+            output_hiddens.append((H_t, C_t))
 
             # TODO (vvchernov): in pytorch implementation train is also checked
             # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
             # /aten/src/ATen/native/RNN.cpp#L1054
-            if dropout_p != 0 and k < hidden_layers_num - 1:
+            if dropout_p != 0 and i < layers_num - 1:
                 # for input in input_seqs:
                 #     input = _op.dropout(input, dropout_p)
                 raise NotImplementedError("Dropout for LSTM has not been supported yet!")
 
         final_hiddens = []
         if bidirectional:
-            for i in range(hidden_layers_num):
-                final_hiddens.append(output_hiddens[i][0])
-                final_hiddens.append(output_hiddens[i][1])
+            for output_hidden in output_hiddens:
+                final_hiddens.append(output_hidden[0])
+                final_hiddens.append(output_hidden[1])
         else:
             final_hiddens = output_hiddens
 
@@ -2513,52 +2463,6 @@ def lstm(self, inputs, input_types):
         else:
             assert weights_num == 2, "The weights number in layer is expected equal to 2"
 
-        weights = []
-        if has_biases:
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_weights = []
-                    rev_weights = []
-                    for j in range(weights_num):
-                        fw_weights.append(_weights[i + j])
-                        rev_weights.append(_weights[i + j + weights_num])
-                    weights.append((fw_weights, rev_weights))
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    fw_weights = []
-                    for j in range(weights_num):
-                        fw_weights.append(_weights[i + j])
-                    weights.append(fw_weights)
-        else:
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_weights = []
-                    rev_weights = []
-                    k = i + weights_num
-                    if has_proj:
-                        fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
-                        rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]]
-                    else:
-                        fw_weights = [_weights[i], _weights[i + 1], None, None]
-                        rev_weights = [_weights[k], _weights[k + 1], None, None]
-                    weights.append((fw_weights, rev_weights))
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    if has_proj:
-                        fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
-                    else:
-                        fw_weights = [_weights[i], _weights[i + 1], None, None]
-                    weights.append(fw_weights)
-        assert (
-            len(weights) == num_layers
-        ), "For stacked LSTM number of weights tuples should be the same as number of layers!"
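
# The tuple packing removed above is replaced by per-layer dictionaries, built in a
# later hunk of this file. A sketch of the resulting structure (illustrative only,
# not part of the patch; the names below are placeholders for relay expressions):
#
#     fw_weights_dict = {
#         "hidden_state": h0_fw,  # (batch, hidden_size)
#         "cell_state": c0_fw,    # (batch, hidden_size)
#         "w_inp": w_ih,          # (4 * hidden_size, feature_size), i-f-c-o gate order
#         "w_hid": w_hh,          # (4 * hidden_size, hidden_size or proj_size)
#         "b_inp": b_ih,          # (4 * hidden_size,), only when has_biases
#         "b_hid": b_hh,          # (4 * hidden_size,), only when has_biases
#     }
#     # one inner list per layer; two dicts per layer when bidirectional
#     layer_weights_dicts = [[fw_weights_dict, rev_weights_dict], ...]
#
# Each dict is splatted directly into lstm_cell(**weights_dict), so the keys must
# match the helper's parameter names (plus an optional "proj" entry when has_proj).
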
-
         X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
         # TODO (vvchernov): Which data type should be used? from input or weights?
         # Instead of it _infer_type(X).checked_type.dtype can be used
@@ -2580,31 +2484,78 @@ def lstm(self, inputs, input_types):
             for i in range(hidden_layers_num):
                 layers_h.append(h_0)
         else:
-            layers_h = self.unbind((h_0, 0), X_dtype)
+            layers_h = unbind(h_0, 0)
         if c_0 is None:
             c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
             for i in range(hidden_layers_num):
                 layers_c.append(c_0)
         else:
-            layers_c = self.unbind((c_0, 0), X_dtype)
+            layers_c = unbind(c_0, 0)
 
-        hiddens = []
-        for i in range(num_layers):
+        layer_weights_dicts = []
+        k = 0  # layer counter
+        if has_biases:
+            names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"]
             if bidirectional:
-                hiddens.append(
-                    ((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1]))
-                )
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    if has_proj:
+                        fw_weights_dict["proj"] = _weights[i + 4]
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    if has_proj:
+                        rev_weights_dict["proj"] = _weights[j + 4]
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
             else:
-                hiddens.append((layers_h[i], layers_c[i]))
+                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    if has_proj:
+                        fw_weights_dict["proj"] = _weights[i + 4]
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        else:
+            names = ["hidden_state", "cell_state", "w_inp", "w_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    if has_proj:
+                        fw_weights_dict["proj"] = _weights[i + 2]
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    if has_proj:
+                        rev_weights_dict["proj"] = _weights[j + 2]
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    if has_proj:
+                        fw_weights_dict["proj"] = _weights[i + 2]
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        assert (
+            len(layer_weights_dicts) == num_layers and k == num_layers
+        ), "For stacked LSTM number of weights sets should be the same as number of layers!"
 
         outputs = self.lstm_layers(
             X,
-            hiddens,
-            weights,
+            layer_weights_dicts,
             bidirectional,
             dtype=X_dtype,
             dropout_p=dropout_p,
-            has_proj=has_proj,
         )
 
         # output shape = (seq_num, batch, hidden_size) or