From df1b9945c164ecd9f9fbc562fb7e94612abdc3ef Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Fri, 23 Jul 2021 16:45:15 +0300
Subject: [PATCH 01/21] fuse dense sum

---
 python/tvm/relay/frontend/pytorch.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 33cb83b883bc..f9dc5477428f 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2345,9 +2345,12 @@ def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
         H_t = hidden[0]  # (batch, hidden_size)
         C_t = hidden[1]  # (batch, hidden_size)
         for x_t in input_seqs:
-            # x_t shape = (batch, feature size)
+            # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size)
+            step = _op.concatenate([x_t, H_t], axis=1)
+            W = _op.concatenate([weights[0], weights[1]], axis = 1)
+            # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W)
             # gates shape = (batch, 4 * hidden_size)
-            gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1])
+            gates = _op.nn.dense(step, W)
             # Add biases
             if weights[2] is not None:
                 gates += weights[2]

From 2a016d94ee9e44cb5f436614a17a864e8c9f9929 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Fri, 23 Jul 2021 16:54:00 +0300
Subject: [PATCH 02/21] remove excess copying

---
 python/tvm/relay/frontend/pytorch.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f9dc5477428f..9e73950abf57 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2363,15 +2363,13 @@ def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
             c = g_act(c)
             o = f_act(o)
 
-            C = f * C_t + i * c
-            H = o * h_act(C)
+            C_t = f * C_t + i * c
+            H_t = o * h_act(C_t)
 
             if has_proj:
-                H = _op.nn.dense(H, weights[4])
+                H_t = _op.nn.dense(H_t, weights[4])
 
-            H_t = H
-            C_t = C
-            outputs_list.append(H)  # [seq_num, (batch, hidden_size)]
+            outputs_list.append(H_t)  # [seq_num, (batch, hidden_size)]
         hidden_outputs = (H_t, C_t)
 
         return (outputs_list, hidden_outputs)

From d3bf383dc56ed03481ab89a55e857a97d1aa07c4 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Wed, 21 Jul 2021 17:19:17 +0300
Subject: [PATCH 03/21] dev LSTM in ONNX

---
 python/tvm/relay/frontend/onnx.py | 173 ++++++++++++++++++++++++++++++
 1 file changed, 173 insertions(+)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 42bde838859a..9cf429b4ca86 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2311,6 +2311,179 @@ def _impl_v7(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)
 
 
+class LSTM_dev(RNN):
+    """Operator converter for LSTM"""
+
+    @classmethod
+    def generate_lstm(
+        cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False
+    ):
+        """Create an unrolled lstm loop.
+
+        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
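+
+        For reference, the per-gate equations from that spec (f, g, h are the three
+        activations, (.) is the elementwise product, Pi/Pf/Po are the peepholes):
+          it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
+          ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
+          ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
+          Ct = ft (.) Ct-1 + it (.) ct
+          ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
+          Ht = ot (.) h(Ct)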
+ """ + h_list = [] + seq_length = len(X_steps) + for i in range(seq_length): + step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] + step = _op.squeeze(step, axis=[0]) + gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) + if B is not None: + WB, RB = _op.split(B, 2) + gates += WB + RB + i, o, f, c = _op.split(gates, 4, axis=-1) + + if p_i != 0: + i = f_act(i + p_i * C_t) + else: + i = f_act(i) + + if p_f != 0: + f = f_act(f + p_f * C_t) + else: + f = f_act(f) + + c = g_act(c) + C = f * C_t + i * c + if p_o != 0: + o = f_act(o + p_o * C) + else: + o = f_act(o) + + H = o * h_act(C) + + H_t = H + C_t = C + h_list.append(_op.expand_dims(H, axis=0)) + + if backwards: + # Canonical view is hidden states from the first token not last + h_list = h_list[::-1] + + # Concatenate outputs and add back in direction axis. + concatenated = _op.concatenate(h_list, 0) + output = _op.expand_dims(concatenated, axis=1) + H_t = _op.expand_dims(H_t, axis=0) + C_t = _op.expand_dims(C_t, axis=0) + + return output, H_t, C_t + + @classmethod + def _impl_v7(cls, inputs, attr, params): + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. + # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + Cp_0 = inputs[6] + Pp = inputs[7] + + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype + + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") + + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + + # Initialize state if not provided. + # Otherwise remove bidirectional axis. + if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Cp_0 is None: + Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Bp is None: + Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype) + if Pp is not None: + p_i, p_o, p_f = _op.split(Pp, 3, axis=1) + else: + p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype) + + if "activations" in attr: + activations = attr["activations"] + if len(activations) != 3 * num_directions: + raise NotImplementedError( + f"LSTM assumes 3 * num_directions activation functions are provided" + ) + alpha_loc = 0 + alphas = attr.get("activation_alpha", []) + if isinstance(alphas, float): + alphas = [alphas] + beta_loc = 0 + betas = attr.get("activation_beta", []) + if isinstance(betas, float): + betas = [betas] + acts = [] + for i in range(3 * num_directions): + alpha = None + beta = None + activation = activations[i] + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: + alpha = alphas[alpha_loc] + alpha_loc += 1 + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: + beta = betas[beta_loc] + beta_loc += 1 + acts.append(cls._activation_helper(activation, alpha, beta)) + else: + acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions + + X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) + result_output = [] + result_H = [] + result_C = [] + + H_ts = _op.split(Hp_0, num_directions) + C_ts = _op.split(Cp_0, num_directions) + Ws = _op.split(Wp, num_directions) + Rs = _op.split(Rp, num_directions) + Bs = _op.split(Bp, num_directions) + p_is = _op.split(p_i, num_directions) + p_fs = _op.split(p_f, num_directions) + p_os = _op.split(p_o, num_directions) + for i in range(num_directions): + H_t = 
_op.squeeze(H_ts[i], axis=[0]) + C_t = _op.squeeze(C_ts[i], axis=[0]) + W = _op.squeeze(Ws[i], axis=[0]) + R = _op.squeeze(Rs[i], axis=[0]) + B = _op.squeeze(Bs[i], axis=[0]) + p_i = _op.squeeze(p_is[i], axis=[0]) + p_f = _op.squeeze(p_fs[i], axis=[0]) + p_o = _op.squeeze(p_os[i], axis=[0]) + + f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3] + output, H, C = LSTM.generate_lstm( + X_steps=X_steps, + H_t=H_t, + C_t=C_t, + W=W, + R=R, + B=B, + p_i=p_i, + p_f=p_f, + p_o=p_o, + f_act=f_act, + g_act=g_act, + h_act=h_act, + backwards=i == 1, + ) + + result_output.append(output) + result_H.append(H) + result_C.append(C) + + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + C = _op.concatenate(result_C, axis=0) + + return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) + + class GRU(RNN): """Operator convert for GRU""" From f737155400e031c5d063d7cd21b0c3bc95d69904 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 22 Jul 2021 22:27:04 +0300 Subject: [PATCH 04/21] alternative implementation of LSTM in onnx frontend. It is quicker than current one without tuning --- python/tvm/relay/frontend/onnx.py | 35 +++++++++++++++++-------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9cf429b4ca86..9d6250303194 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -23,6 +23,7 @@ import numpy as np import tvm from tvm.ir import IRModule +from tvm.relay.op.tensor import concatenate from tvm.topi.utils import get_const_tuple from ... import nd as _nd @@ -2316,7 +2317,7 @@ class LSTM_dev(RNN): @classmethod def generate_lstm( - cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False + cls, X_steps, H_t, C_t, W, R, WB, RB, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False ): """Create an unrolled lstm loop. 
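+
+        Relative to the previous revision, the biases arrive pre-split as WB and RB,
+        so the per-step _op.split(B, 2) is hoisted out of the sequence loop (it is
+        now done once per direction in _impl_v7 instead).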
@@ -2328,11 +2329,9 @@ def generate_lstm( step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] step = _op.squeeze(step, axis=[0]) gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) - if B is not None: - WB, RB = _op.split(B, 2) + if WB is not None and RB is not None: gates += WB + RB i, o, f, c = _op.split(gates, 4, axis=-1) - if p_i != 0: i = f_act(i + p_i * C_t) else: @@ -2344,17 +2343,17 @@ def generate_lstm( f = f_act(f) c = g_act(c) - C = f * C_t + i * c + C_t = f * C_t + i * c if p_o != 0: - o = f_act(o + p_o * C) + o = f_act(o + p_o * C_t) else: o = f_act(o) - H = o * h_act(C) + H_t = o * h_act(C_t) - H_t = H - C_t = C - h_list.append(_op.expand_dims(H, axis=0)) + #H_t = H + h_list.append(_op.expand_dims(H_t, axis=0)) + #h_list.append(H_t) if backwards: # Canonical view is hidden states from the first token not last @@ -2430,8 +2429,11 @@ def _impl_v7(cls, inputs, attr, params): beta = betas[beta_loc] beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) + f_act, g_act, h_act = acts else: - acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions + f_act = _op.sigmoid + g_act = _op.tanh + h_act = _op.tanh X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) result_output = [] @@ -2451,19 +2453,20 @@ def _impl_v7(cls, inputs, attr, params): C_t = _op.squeeze(C_ts[i], axis=[0]) W = _op.squeeze(Ws[i], axis=[0]) R = _op.squeeze(Rs[i], axis=[0]) - B = _op.squeeze(Bs[i], axis=[0]) + #B = _op.squeeze(Bs[i], axis=[0]) + WB, RB = _op.split(Bs[i], 2, -1) p_i = _op.squeeze(p_is[i], axis=[0]) p_f = _op.squeeze(p_fs[i], axis=[0]) p_o = _op.squeeze(p_os[i], axis=[0]) - f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3] - output, H, C = LSTM.generate_lstm( + output, H, C = LSTM_dev.generate_lstm( X_steps=X_steps, H_t=H_t, C_t=C_t, W=W, R=R, - B=B, + WB=WB, + RB=RB, p_i=p_i, p_f=p_f, p_o=p_o, @@ -3700,7 +3703,7 @@ def _get_convert_map(opset): "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers - "LSTM": LSTM.get_converter(opset), + "LSTM": LSTM_dev.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision "MaxRoiPool": MaxRoiPool.get_converter(opset), From 79bef55f84800e3752ab36465a3aa222c88f9db3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 26 Jul 2021 17:47:13 +0300 Subject: [PATCH 05/21] LSTM_dev2 was implemented in onnx frontend --- python/tvm/relay/frontend/onnx.py | 254 +++++++++++++++++++++++++++++- 1 file changed, 250 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9d6250303194..5252f103fd9c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2351,9 +2351,7 @@ def generate_lstm( H_t = o * h_act(C_t) - #H_t = H h_list.append(_op.expand_dims(H_t, axis=0)) - #h_list.append(H_t) if backwards: # Canonical view is hidden states from the first token not last @@ -2453,7 +2451,6 @@ def _impl_v7(cls, inputs, attr, params): C_t = _op.squeeze(C_ts[i], axis=[0]) W = _op.squeeze(Ws[i], axis=[0]) R = _op.squeeze(Rs[i], axis=[0]) - #B = _op.squeeze(Bs[i], axis=[0]) WB, RB = _op.split(Bs[i], 2, -1) p_i = _op.squeeze(p_is[i], axis=[0]) p_f = _op.squeeze(p_fs[i], axis=[0]) @@ -2487,6 +2484,255 @@ def _impl_v7(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) +class LSTM_dev2(RNN): + """Operator converter for LSTM""" + + # TODO (vvchernov): unbind was gotten from pytorch.py and modified. 
+ # It looks like torch.unbind + # It needs such operation on relay side to avoid excess manipulation like squeeze + @classmethod + def unbind(cls, data, axis=0): + shape = infer_shape(data) + selections = shape[axis] + res_split = _op.split(data, selections, axis) + ret = [] + for i in range(selections): + ret.append(_op.squeeze(res_split[i], axis=[axis])) + ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) + return ret + + @classmethod + def lstm_cell( + cls, + input_seqs, + H_t, + C_t, + Wi, + Wh, + Bi, + Bh, + P, + p_i, + p_f, + p_o, + f_act, + g_act, + h_act, + backwards=False, + ): + # Input hidden state shape = (batch, hidden_size) + # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. + # Wi and Wh shoud exist the others can be None + + outputs_list = [] + for x_t in input_seqs if not backwards else reversed(input_seqs): + # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) + step = _op.concatenate([x_t, H_t], axis=1) + W = _op.concatenate([Wi, Wh], axis=1) + # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) + # gates shape = (batch, 4 * hidden_size) + gates = _op.nn.dense(step, W) + # Add biases + if Bi is not None: + gates += Bi + if Bh is not None: + gates += Bh + i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) + + if p_i is not None and p_f is not None: + i = f_act(i + p_i * C_t) + f = f_act(f + p_f * C_t) + else: + i = f_act(i) + f = f_act(f) + + c = g_act(c) + C_t = f * C_t + i * c + if p_o is not None: + o = f_act(o + p_o * C_t) + else: + o = f_act(o) + + H_t = o * h_act(C_t) + + if P is not None: + H_t = _op.nn.dense(H_t, P) + + outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] + + return outputs_list, H_t, C_t + + @classmethod + def bidir_lstm_cell( + cls, + input_seqs, + weight_dicts, + f_act, + g_act, + h_act, + ): + seq_len = len(input_seqs) + forward_outputs, fw_H_t, fw_C_t = LSTM_dev2.lstm_cell( + input_seqs, + **weight_dicts[0], + f_act=f_act, + g_act=g_act, + h_act=h_act, + ) + + reverse_outputs, rev_H_t, rev_C_t = LSTM_dev2.lstm_cell( + input_seqs, + **weight_dicts[1], + f_act=f_act, + g_act=g_act, + h_act=h_act, + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) + + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + _op.stack([fw_C_t, rev_C_t], axis=0), + ) + + @classmethod + def _impl_v7(cls, inputs, attr, params): + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. + # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + Cp_0 = inputs[6] + Pp = inputs[7] + + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype + + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") + + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + + # Initialize state if not provided. + # Otherwise remove bidirectional axis. 
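+        # Per the ONNX spec, the optional initial_h/initial_c inputs default to zeros
+        # of shape (num_directions, batch_size, hidden_size) when they are omitted.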
+ if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Cp_0 is None: + Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + + if "activations" in attr: + activations = attr["activations"] + if len(activations) != 3 * num_directions: + raise NotImplementedError( + f"LSTM assumes 3 * num_directions activation functions are provided" + ) + alpha_loc = 0 + alphas = attr.get("activation_alpha", []) + if isinstance(alphas, float): + alphas = [alphas] + beta_loc = 0 + betas = attr.get("activation_beta", []) + if isinstance(betas, float): + betas = [betas] + acts = [] + for i in range(3 * num_directions): + alpha = None + beta = None + activation = activations[i] + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: + alpha = alphas[alpha_loc] + alpha_loc += 1 + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: + beta = betas[beta_loc] + beta_loc += 1 + acts.append(cls._activation_helper(activation, alpha, beta)) + f_act, g_act, h_act = acts + else: + f_act = _op.sigmoid + g_act = _op.tanh + h_act = _op.tanh + + # It can be replaced by _op.split if issue #8412 is resolved + X_steps = LSTM_dev2.unbind(X, axis=0) + + H_ts = _op.split(Hp_0, num_directions) + C_ts = _op.split(Cp_0, num_directions) + Ws = _op.split(Wp, num_directions) + Rs = _op.split(Rp, num_directions) + + if Bp is not None: + Bs = _op.split(Bp, num_directions) + if Pp is not None: + p_i, p_o, p_f = _op.split(Pp, 3, axis=1) + + p_is = _op.split(p_i, num_directions) + p_fs = _op.split(p_f, num_directions) + p_os = _op.split(p_o, num_directions) + + weights_dicts = [] + for i in range(num_directions): + weights_dict = {} + + weights_dict["H_t"] = _op.squeeze(H_ts[i], axis=[0]) + weights_dict["C_t"] = _op.squeeze(C_ts[i], axis=[0]) + + weights_dict["Wi"] = _op.squeeze(Ws[i], axis=[0]) + weights_dict["Wh"] = _op.squeeze(Rs[i], axis=[0]) + if Bp is None: + Bi = None + Bh = None + else: + Bi, Bh = _op.split(Bs[i], 2, -1) + weights_dict["Bi"] = Bi + weights_dict["Bh"] = Bh + weights_dict["P"] = None + if Pp is None: + weights_dict["p_i"] = None + weights_dict["p_f"] = None + weights_dict["p_o"] = None + else: + weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0]) + weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0]) + weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0]) + weights_dicts.append(weights_dict) + + if num_directions == 2: + output, H, C = LSTM_dev2.bidir_lstm_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + f_act=f_act, + g_act=g_act, + h_act=h_act, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H, C = LSTM_dev2.lstm_cell( + input_seqs=X_steps, + **weights_dicts[0], + f_act=f_act, + g_act=g_act, + h_act=h_act, + ) + + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=1) + C = _op.expand_dims(C, axis=1) + + return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) + + class GRU(RNN): """Operator convert for GRU""" @@ -3703,7 +3949,7 @@ def _get_convert_map(opset): "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers - "LSTM": LSTM_dev.get_converter(opset), + "LSTM": LSTM_dev2.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision "MaxRoiPool": MaxRoiPool.get_converter(opset), From e113728064d1ffc7e54ea27968929bc9d57af7b7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 29 Jul 2021 12:36:11 +0300 Subject: [PATCH 06/21] 
LSTM dev in pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 256 +++++++++++++++++++-------- 1 file changed, 180 insertions(+), 76 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9e73950abf57..cc0fe45106db 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2330,6 +2330,91 @@ def flip(self, inputs, input_types): axis = inputs[1] return _op.transform.reverse(data, axis=axis[0]) + def lstm_cell_dev( + self, + input_seqs, + H_t, + C_t, + Wi, + Wh, + Bi=None, + Bh=None, + P=None, + p_i=None, + p_f=None, + p_o=None, + f_act=_op.sigmoid, + g_act=_op.tanh, + h_act=_op.tanh, + backwards=False, + ): + # Input hidden state shape = (batch, hidden_size) + # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. + # Wi and Wh shoud exist the others can be None + + outputs_list = [] + for x_t in input_seqs if not backwards else reversed(input_seqs): + # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) + step = _op.concatenate([x_t, H_t], axis=1) + W = _op.concatenate([Wi, Wh], axis=1) + # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) + # gates shape = (batch, 4 * hidden_size) + gates = _op.nn.dense(step, W) + # Add biases + if Bi is not None: + gates += Bi + if Bh is not None: + gates += Bh + i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) + + if p_i is not None and p_f is not None: + i = f_act(i + p_i * C_t) + f = f_act(f + p_f * C_t) + else: + i = f_act(i) + f = f_act(f) + + c = g_act(c) + C_t = f * C_t + i * c + if p_o is not None: + o = f_act(o + p_o * C_t) + else: + o = f_act(o) + + H_t = o * h_act(C_t) + + if P is not None: + H_t = _op.nn.dense(H_t, P) + + outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] + + return outputs_list, H_t, C_t + + def bidir_lstm_cell_dev( + self, + input_seqs, + weights_dicts, + ): + seq_len = len(input_seqs) + forward_outputs, fw_H_t, fw_C_t = self.lstm_cell_dev( + input_seqs, + **weights_dicts[0], + ) + + reverse_outputs, rev_H_t, rev_C_t = self.lstm_cell_dev( + input_seqs, + **weights_dicts[1], + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1) + ) + + return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t) + def lstm_cell(self, input_seqs, hidden, weights, has_proj=False): if has_proj: assert len(weights) == 5 @@ -2347,7 +2432,7 @@ def lstm_cell(self, input_seqs, hidden, weights, has_proj=False): for x_t in input_seqs: # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) step = _op.concatenate([x_t, H_t], axis=1) - W = _op.concatenate([weights[0], weights[1]], axis = 1) + W = _op.concatenate([weights[0], weights[1]], axis=1) # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) # gates shape = (batch, 4 * hidden_size) gates = _op.nn.dense(step, W) @@ -2392,41 +2477,37 @@ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False): return final_outputs, (fw_outputs[1], rev_outputs[1]) def lstm_layers( - self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False + self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0 ): - hidden_layers_num = len(hiddens) - assert len(weights) == hidden_layers_num - + layers_num = 
len(layer_weights_dicts) # split input sequence to samples set input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)] output_hiddens = [] - for k in range(hidden_layers_num): - hiddens_input = hiddens[k] - weights_input = weights[k] - - outputs = ( - self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj) - if bidirectional - else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj) - ) - - output_hiddens.append(outputs[1]) + for i in range(layers_num): + weights_dicts=layer_weights_dicts[i] # input_seqs shape = [seq_num, (batch, feature_size)] or # [seq_num, (batch, 2*feature_size)] for bidirectional - input_seqs = outputs[0] + if bidirectional: + input_seqs, H_t, C_t = self.bidir_lstm_cell_dev( + input_seqs, weights_dicts + ) + else: + input_seqs, H_t, C_t = self.lstm_cell_dev(input_seqs, **weights_dicts[0]) + + output_hiddens.append((H_t, C_t)) # TODO (vvchernov): in pytorch implementation train is also checked # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339 # /aten/src/ATen/native/RNN.cpp#L1054 - if dropout_p != 0 and k < hidden_layers_num - 1: + if dropout_p != 0 and i < layers_num - 1: # for input in input_seqs: # input = _op.dropout(input, dropout_p) raise NotImplementedError("Dropout for LSTM has not been supported yet!") final_hiddens = [] if bidirectional: - for i in range(hidden_layers_num): - final_hiddens.append(output_hiddens[i][0]) - final_hiddens.append(output_hiddens[i][1]) + for output_hidden in output_hiddens: + final_hiddens.append(output_hidden[0]) + final_hiddens.append(output_hidden[1]) else: final_hiddens = output_hiddens @@ -2514,52 +2595,6 @@ def lstm(self, inputs, input_types): else: assert weights_num == 2, "The weights number in layer is expected equal to 2" - weights = [] - if has_biases: - if bidirectional: - rsd = len(_weights) % (2 * weights_num) - assert rsd == 0, "got an incorrect number of LSTM weights" - for i in range(0, len(_weights), 2 * weights_num): - fw_weights = [] - rev_weights = [] - for j in range(weights_num): - fw_weights.append(_weights[i + j]) - rev_weights.append(_weights[i + j + weights_num]) - weights.append((fw_weights, rev_weights)) - else: - assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" - for i in range(0, len(_weights), weights_num): - fw_weights = [] - for j in range(weights_num): - fw_weights.append(_weights[i + j]) - weights.append(fw_weights) - else: - if bidirectional: - rsd = len(_weights) % (2 * weights_num) - assert rsd == 0, "got an incorrect number of LSTM weights" - for i in range(0, len(_weights), 2 * weights_num): - fw_weights = [] - rev_weights = [] - k = i + weights_num - if has_proj: - fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]] - rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]] - else: - fw_weights = [_weights[i], _weights[i + 1], None, None] - rev_weights = [_weights[k], _weights[k + 1], None, None] - weights.append((fw_weights, rev_weights)) - else: - assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" - for i in range(0, len(_weights), weights_num): - if has_proj: - fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]] - else: - fw_weights = [_weights[i], _weights[i + 1], None, None] - weights.append(fw_weights) - assert ( - len(weights) == num_layers - ), "For stacked LSTM number of weights tuples should be the same as number of layers!" 
- X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X # TODO (vvchernov): Which data type should be used? from input or weights? # Instead of it _infer_type(X).checked_type.dtype can be used @@ -2589,23 +2624,92 @@ def lstm(self, inputs, input_types): else: layers_c = self.unbind((c_0, 0), X_dtype) - hiddens = [] - for i in range(num_layers): + layer_weights_dicts = [] + k=0 # layer counter + if has_biases: if bidirectional: - hiddens.append( - ((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1])) - ) + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_weights_dict = {} + fw_weights_dict["H_t"]=layers_h[2*k] + fw_weights_dict["C_t"]=layers_c[2*k] + fw_weights_dict["Wi"]=_weights[i] + fw_weights_dict["Wh"]=_weights[i+1] + fw_weights_dict["Bi"]=_weights[i+2] + fw_weights_dict["Bh"]=_weights[i+3] + if has_proj: + fw_weights_dict["P"]=_weights[i+4] + rev_weights_dict = {} + j=i+weights_num + rev_weights_dict["H_t"]=layers_h[2*k+1] + rev_weights_dict["C_t"]=layers_c[2*k+1] + rev_weights_dict["Wi"]=_weights[j] + rev_weights_dict["Wh"]=_weights[j+1] + rev_weights_dict["Bi"]=_weights[j+2] + rev_weights_dict["Bh"]=_weights[j+3] + if has_proj: + rev_weights_dict["P"]=_weights[j+4] + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k+=1 else: - hiddens.append((layers_h[i], layers_c[i])) + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + fw_weights_dict = {} + fw_weights_dict["H_t"]=layers_h[k] + fw_weights_dict["C_t"]=layers_c[k] + fw_weights_dict["Wi"]=_weights[i] + fw_weights_dict["Wh"]=_weights[i+1] + fw_weights_dict["Bi"]=_weights[i+2] + fw_weights_dict["Bh"]=_weights[i+3] + if has_proj: + fw_weights_dict["P"]=_weights[i+4] + layer_weights_dicts.append([fw_weights_dict]) + k+=1 + else: + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_weights_dict = {} + fw_weights_dict["H_t"]=layers_h[2*k] + fw_weights_dict["C_t"]=layers_c[2*k] + fw_weights_dict["Wi"]=_weights[i] + fw_weights_dict["Wh"]=_weights[i+1] + if has_proj: + fw_weights_dict["P"]=_weights[i+2] + rev_weights_dict = {} + j=i+weights_num + rev_weights_dict["H_t"]=layers_h[2*k+1] + rev_weights_dict["C_t"]=layers_c[2*k+1] + rev_weights_dict["Wi"]=_weights[j] + rev_weights_dict["Wh"]=_weights[j+1] + if has_proj: + rev_weights_dict["P"]=_weights[j+2] + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k+=1 + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + fw_weights_dict = {} + fw_weights_dict["H_t"]=layers_h[k] + fw_weights_dict["C_t"]=layers_c[k] + fw_weights_dict["Wi"]=_weights[i] + fw_weights_dict["Wh"]=_weights[i+1] + if has_proj: + fw_weights_dict["P"]=_weights[i+2] + layer_weights_dicts.append([fw_weights_dict]) + k+=1 + assert ( + len(layer_weights_dicts) == num_layers and k == num_layers + ), "For stacked LSTM number of weights sets should be the same as number of layers!" 
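+        # Each element of layer_weights_dicts is [fw_weights_dict] or
+        # [fw_weights_dict, rev_weights_dict]; the dict keys (H_t, C_t, Wi, Wh,
+        # Bi, Bh, P) mirror the lstm_cell_dev keyword arguments, so lstm_layers
+        # can apply each dict with **.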
outputs = self.lstm_layers( X, - hiddens, - weights, + layer_weights_dicts, bidirectional, dtype=X_dtype, dropout_p=dropout_p, - has_proj=has_proj, ) # output shape = (seq_num, batch, hidden_size) or From f2860ce76f6752d7a7def21c06a7f8036e593c6e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 30 Jul 2021 10:41:42 +0300 Subject: [PATCH 07/21] LSTM cell implementation was transferred to common place. Unneccessary code was removed --- python/tvm/relay/frontend/onnx.py | 436 +-------------------------- python/tvm/relay/frontend/pytorch.py | 131 +------- python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/layers.py | 82 +++++ 4 files changed, 99 insertions(+), 551 deletions(-) create mode 100644 python/tvm/relay/op/layers.py diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5252f103fd9c..f29308459a67 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2142,351 +2142,6 @@ def _activation_needs_beta(cls, activation): class LSTM(RNN): """Operator converter for LSTM""" - @classmethod - def generate_lstm( - cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False - ): - """Create an unrolled lstm loop. - - See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. - """ - h_list = [] - seq_length = len(X_steps) - for i in range(seq_length): - step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] - step = _op.squeeze(step, axis=[0]) - gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) - if B is not None: - WB, RB = _op.split(B, 2) - gates += WB + RB - i, o, f, c = _op.split(gates, 4, axis=-1) - - if p_i != 0: - i = f_act(i + p_i * C_t) - else: - i = f_act(i) - - if p_f != 0: - f = f_act(f + p_f * C_t) - else: - f = f_act(f) - - c = g_act(c) - C = f * C_t + i * c - if p_o != 0: - o = f_act(o + p_o * C) - else: - o = f_act(o) - - H = o * h_act(C) - - H_t = H - C_t = C - h_list.append(_op.expand_dims(H, axis=0)) - - if backwards: - # Canonical view is hidden states from the first token not last - h_list = h_list[::-1] - - # Concatenate outputs and add back in direction axis. - concatenated = _op.concatenate(h_list, 0) - output = _op.expand_dims(concatenated, axis=1) - H_t = _op.expand_dims(H_t, axis=0) - C_t = _op.expand_dims(C_t, axis=0) - - return output, H_t, C_t - - @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - Cp_0 = inputs[6] - Pp = inputs[7] - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - # Initialize state if not provided. - # Otherwise remove bidirectional axis. 
- if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Cp_0 is None: - Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Bp is None: - Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype) - if Pp is not None: - p_i, p_o, p_f = _op.split(Pp, 3, axis=1) - else: - p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype) - - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 3 * num_directions: - raise NotImplementedError( - f"LSTM assumes 3 * num_directions activation functions are provided" - ) - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(3 * num_directions): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) - else: - acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions - - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) - result_output = [] - result_H = [] - result_C = [] - - H_ts = _op.split(Hp_0, num_directions) - C_ts = _op.split(Cp_0, num_directions) - Ws = _op.split(Wp, num_directions) - Rs = _op.split(Rp, num_directions) - Bs = _op.split(Bp, num_directions) - p_is = _op.split(p_i, num_directions) - p_fs = _op.split(p_f, num_directions) - p_os = _op.split(p_o, num_directions) - for i in range(num_directions): - H_t = _op.squeeze(H_ts[i], axis=[0]) - C_t = _op.squeeze(C_ts[i], axis=[0]) - W = _op.squeeze(Ws[i], axis=[0]) - R = _op.squeeze(Rs[i], axis=[0]) - B = _op.squeeze(Bs[i], axis=[0]) - p_i = _op.squeeze(p_is[i], axis=[0]) - p_f = _op.squeeze(p_fs[i], axis=[0]) - p_o = _op.squeeze(p_os[i], axis=[0]) - - f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3] - output, H, C = LSTM.generate_lstm( - X_steps=X_steps, - H_t=H_t, - C_t=C_t, - W=W, - R=R, - B=B, - p_i=p_i, - p_f=p_f, - p_o=p_o, - f_act=f_act, - g_act=g_act, - h_act=h_act, - backwards=i == 1, - ) - - result_output.append(output) - result_H.append(H) - result_C.append(C) - - output = _op.concatenate(result_output, axis=1) - H = _op.concatenate(result_H, axis=0) - C = _op.concatenate(result_C, axis=0) - - return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) - - -class LSTM_dev(RNN): - """Operator converter for LSTM""" - - @classmethod - def generate_lstm( - cls, X_steps, H_t, C_t, W, R, WB, RB, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False - ): - """Create an unrolled lstm loop. - - See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. 
- """ - h_list = [] - seq_length = len(X_steps) - for i in range(seq_length): - step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] - step = _op.squeeze(step, axis=[0]) - gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) - if WB is not None and RB is not None: - gates += WB + RB - i, o, f, c = _op.split(gates, 4, axis=-1) - if p_i != 0: - i = f_act(i + p_i * C_t) - else: - i = f_act(i) - - if p_f != 0: - f = f_act(f + p_f * C_t) - else: - f = f_act(f) - - c = g_act(c) - C_t = f * C_t + i * c - if p_o != 0: - o = f_act(o + p_o * C_t) - else: - o = f_act(o) - - H_t = o * h_act(C_t) - - h_list.append(_op.expand_dims(H_t, axis=0)) - - if backwards: - # Canonical view is hidden states from the first token not last - h_list = h_list[::-1] - - # Concatenate outputs and add back in direction axis. - concatenated = _op.concatenate(h_list, 0) - output = _op.expand_dims(concatenated, axis=1) - H_t = _op.expand_dims(H_t, axis=0) - C_t = _op.expand_dims(C_t, axis=0) - - return output, H_t, C_t - - @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - Wp = inputs[1] - Rp = inputs[2] - Bp = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - Hp_0 = inputs[5] - Cp_0 = inputs[6] - Pp = inputs[7] - - num_directions = infer_shape(Wp)[0] - W_dtype = infer_type(Wp).checked_type.dtype - - if num_directions not in [1, 2]: - raise ValueError("num_directions must be either 1 or 2!") - - X_shape = infer_shape(X) - hidden_size = infer_shape(Rp)[-1] - batch_size = X_shape[1] - - # Initialize state if not provided. - # Otherwise remove bidirectional axis. - if Hp_0 is None: - Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Cp_0 is None: - Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Bp is None: - Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype) - if Pp is not None: - p_i, p_o, p_f = _op.split(Pp, 3, axis=1) - else: - p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype) - - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 3 * num_directions: - raise NotImplementedError( - f"LSTM assumes 3 * num_directions activation functions are provided" - ) - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(3 * num_directions): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) - f_act, g_act, h_act = acts - else: - f_act = _op.sigmoid - g_act = _op.tanh - h_act = _op.tanh - - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) - result_output = [] - result_H = [] - result_C = [] - - H_ts = _op.split(Hp_0, num_directions) - C_ts = _op.split(Cp_0, num_directions) - Ws = _op.split(Wp, num_directions) - Rs = _op.split(Rp, num_directions) - Bs = _op.split(Bp, num_directions) - p_is = _op.split(p_i, num_directions) - p_fs = _op.split(p_f, num_directions) - p_os = _op.split(p_o, num_directions) - for i in range(num_directions): 
- H_t = _op.squeeze(H_ts[i], axis=[0]) - C_t = _op.squeeze(C_ts[i], axis=[0]) - W = _op.squeeze(Ws[i], axis=[0]) - R = _op.squeeze(Rs[i], axis=[0]) - WB, RB = _op.split(Bs[i], 2, -1) - p_i = _op.squeeze(p_is[i], axis=[0]) - p_f = _op.squeeze(p_fs[i], axis=[0]) - p_o = _op.squeeze(p_os[i], axis=[0]) - - output, H, C = LSTM_dev.generate_lstm( - X_steps=X_steps, - H_t=H_t, - C_t=C_t, - W=W, - R=R, - WB=WB, - RB=RB, - p_i=p_i, - p_f=p_f, - p_o=p_o, - f_act=f_act, - g_act=g_act, - h_act=h_act, - backwards=i == 1, - ) - - result_output.append(output) - result_H.append(H) - result_C.append(C) - - output = _op.concatenate(result_output, axis=1) - H = _op.concatenate(result_H, axis=0) - C = _op.concatenate(result_C, axis=0) - - return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) - - -class LSTM_dev2(RNN): - """Operator converter for LSTM""" - # TODO (vvchernov): unbind was gotten from pytorch.py and modified. # It looks like torch.unbind # It needs such operation on relay side to avoid excess manipulation like squeeze @@ -2501,67 +2156,6 @@ def unbind(cls, data, axis=0): ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) return ret - @classmethod - def lstm_cell( - cls, - input_seqs, - H_t, - C_t, - Wi, - Wh, - Bi, - Bh, - P, - p_i, - p_f, - p_o, - f_act, - g_act, - h_act, - backwards=False, - ): - # Input hidden state shape = (batch, hidden_size) - # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. - # Wi and Wh shoud exist the others can be None - - outputs_list = [] - for x_t in input_seqs if not backwards else reversed(input_seqs): - # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) - step = _op.concatenate([x_t, H_t], axis=1) - W = _op.concatenate([Wi, Wh], axis=1) - # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) - # gates shape = (batch, 4 * hidden_size) - gates = _op.nn.dense(step, W) - # Add biases - if Bi is not None: - gates += Bi - if Bh is not None: - gates += Bh - i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) - - if p_i is not None and p_f is not None: - i = f_act(i + p_i * C_t) - f = f_act(f + p_f * C_t) - else: - i = f_act(i) - f = f_act(f) - - c = g_act(c) - C_t = f * C_t + i * c - if p_o is not None: - o = f_act(o + p_o * C_t) - else: - o = f_act(o) - - H_t = o * h_act(C_t) - - if P is not None: - H_t = _op.nn.dense(H_t, P) - - outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] - - return outputs_list, H_t, C_t - @classmethod def bidir_lstm_cell( cls, @@ -2572,7 +2166,7 @@ def bidir_lstm_cell( h_act, ): seq_len = len(input_seqs) - forward_outputs, fw_H_t, fw_C_t = LSTM_dev2.lstm_cell( + forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell( input_seqs, **weight_dicts[0], f_act=f_act, @@ -2580,7 +2174,7 @@ def bidir_lstm_cell( h_act=h_act, ) - reverse_outputs, rev_H_t, rev_C_t = LSTM_dev2.lstm_cell( + reverse_outputs, rev_H_t, rev_C_t = _op.lstm_cell( input_seqs, **weight_dicts[1], f_act=f_act, @@ -2663,8 +2257,8 @@ def _impl_v7(cls, inputs, attr, params): g_act = _op.tanh h_act = _op.tanh - # It can be replaced by _op.split if issue #8412 is resolved - X_steps = LSTM_dev2.unbind(X, axis=0) + # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved + X_steps = LSTM.unbind(X, axis=0) H_ts = _op.split(Hp_0, num_directions) C_ts = _op.split(Cp_0, num_directions) @@ -2689,26 +2283,18 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["Wi"] = _op.squeeze(Ws[i], axis=[0]) weights_dict["Wh"] = 
_op.squeeze(Rs[i], axis=[0]) - if Bp is None: - Bi = None - Bh = None - else: + if Bp is not None: Bi, Bh = _op.split(Bs[i], 2, -1) - weights_dict["Bi"] = Bi - weights_dict["Bh"] = Bh - weights_dict["P"] = None - if Pp is None: - weights_dict["p_i"] = None - weights_dict["p_f"] = None - weights_dict["p_o"] = None - else: + weights_dict["Bi"] = Bi + weights_dict["Bh"] = Bh + if Pp is not None: weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0]) weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0]) weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0]) weights_dicts.append(weights_dict) if num_directions == 2: - output, H, C = LSTM_dev2.bidir_lstm_cell( + output, H, C = LSTM.bidir_lstm_cell( input_seqs=X_steps, weight_dicts=weights_dicts, f_act=f_act, @@ -2717,7 +2303,7 @@ def _impl_v7(cls, inputs, attr, params): ) else: # outputs shape = [seqs_num, (batch_size, hidden_size)] - outputs, H, C = LSTM_dev2.lstm_cell( + outputs, H, C = _op.lstm_cell( input_seqs=X_steps, **weights_dicts[0], f_act=f_act, @@ -3949,7 +3535,7 @@ def _get_convert_map(opset): "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers - "LSTM": LSTM_dev2.get_converter(opset), + "LSTM": LSTM.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision "MaxRoiPool": MaxRoiPool.get_converter(opset), diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index cc0fe45106db..036588c96645 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2330,78 +2330,18 @@ def flip(self, inputs, input_types): axis = inputs[1] return _op.transform.reverse(data, axis=axis[0]) - def lstm_cell_dev( - self, - input_seqs, - H_t, - C_t, - Wi, - Wh, - Bi=None, - Bh=None, - P=None, - p_i=None, - p_f=None, - p_o=None, - f_act=_op.sigmoid, - g_act=_op.tanh, - h_act=_op.tanh, - backwards=False, - ): - # Input hidden state shape = (batch, hidden_size) - # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. 
- # Wi and Wh shoud exist the others can be None - - outputs_list = [] - for x_t in input_seqs if not backwards else reversed(input_seqs): - # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) - step = _op.concatenate([x_t, H_t], axis=1) - W = _op.concatenate([Wi, Wh], axis=1) - # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) - # gates shape = (batch, 4 * hidden_size) - gates = _op.nn.dense(step, W) - # Add biases - if Bi is not None: - gates += Bi - if Bh is not None: - gates += Bh - i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) - - if p_i is not None and p_f is not None: - i = f_act(i + p_i * C_t) - f = f_act(f + p_f * C_t) - else: - i = f_act(i) - f = f_act(f) - - c = g_act(c) - C_t = f * C_t + i * c - if p_o is not None: - o = f_act(o + p_o * C_t) - else: - o = f_act(o) - - H_t = o * h_act(C_t) - - if P is not None: - H_t = _op.nn.dense(H_t, P) - - outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] - - return outputs_list, H_t, C_t - - def bidir_lstm_cell_dev( + def bidir_lstm_cell( self, input_seqs, weights_dicts, ): seq_len = len(input_seqs) - forward_outputs, fw_H_t, fw_C_t = self.lstm_cell_dev( + forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell( input_seqs, **weights_dicts[0], ) - reverse_outputs, rev_H_t, rev_C_t = self.lstm_cell_dev( + reverse_outputs, rev_H_t, rev_C_t = _op.lstm_cell( input_seqs, **weights_dicts[1], backwards=True, @@ -2415,67 +2355,6 @@ def bidir_lstm_cell_dev( return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t) - def lstm_cell(self, input_seqs, hidden, weights, has_proj=False): - if has_proj: - assert len(weights) == 5 - else: - assert len(weights) == 4 - outputs_list = [] - # Default activations types - f_act = _op.sigmoid - g_act = _op.tanh - h_act = _op.tanh - - # Input hiddens - H_t = hidden[0] # (batch, hidden_size) - C_t = hidden[1] # (batch, hidden_size) - for x_t in input_seqs: - # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) - step = _op.concatenate([x_t, H_t], axis=1) - W = _op.concatenate([weights[0], weights[1]], axis=1) - # Instead of _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) we have _op.nn.dense(step, W) - # gates shape = (batch, 4 * hidden_size) - gates = _op.nn.dense(step, W) - # Add biases - if weights[2] is not None: - gates += weights[2] - if weights[3] is not None: - gates += weights[3] - i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) - - i = f_act(i) - f = f_act(f) - c = g_act(c) - o = f_act(o) - - C_t = f * C_t + i * c - H_t = o * h_act(C_t) - - if has_proj: - H_t = _op.nn.dense(H_t, weights[4]) - - outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] - hidden_outputs = (H_t, C_t) - - return (outputs_list, hidden_outputs) - - def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False): - fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj) - - rev_input_seq = [] - seq_len = len(input_seq) - for i in range(seq_len): - rev_input_seq.append(input_seq[seq_len - 1 - i]) # [seq_num, (batch, hidden_size)] - rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj) - - final_outputs = [] # [seq_num, (batch, 2 * hidden_size)] - for j in range(seq_len): - final_outputs.append( - _op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1) - ) - - return final_outputs, (fw_outputs[1], rev_outputs[1]) - def lstm_layers( self, input_data, layer_weights_dicts, 
bidirectional, dtype, dropout_p=0.0 ): @@ -2488,11 +2367,11 @@ def lstm_layers( # input_seqs shape = [seq_num, (batch, feature_size)] or # [seq_num, (batch, 2*feature_size)] for bidirectional if bidirectional: - input_seqs, H_t, C_t = self.bidir_lstm_cell_dev( + input_seqs, H_t, C_t = self.bidir_lstm_cell( input_seqs, weights_dicts ) else: - input_seqs, H_t, C_t = self.lstm_cell_dev(input_seqs, **weights_dicts[0]) + input_seqs, H_t, C_t = _op.lstm_cell(input_seqs, **weights_dicts[0]) output_hiddens.append((H_t, C_t)) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 2e509a111c4a..7412042b00ca 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -38,6 +38,7 @@ from .tensor import * from .transform import * from .algorithm import * +from .layers import * from . import vm from . import nn from . import annotation diff --git a/python/tvm/relay/op/layers.py b/python/tvm/relay/op/layers.py new file mode 100644 index 000000000000..b386ede4ba48 --- /dev/null +++ b/python/tvm/relay/op/layers.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Intermediate representation of complicated layers unified for all frontends""" + +from .tensor import sigmoid, tanh, concatenate +from .transform import split +from . import nn + + +def lstm_cell( + input_seqs, + H_t, + C_t, + Wi, + Wh, + Bi=None, + Bh=None, + P=None, + p_i=None, + p_f=None, + p_o=None, + f_act=sigmoid, + g_act=tanh, + h_act=tanh, + backwards=False, +): + # Input hidden state shape = (batch, hidden_size) + # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. 
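+    # Assumed shapes: Wi (4 * hidden_size, feature_size), Wh (4 * hidden_size,
+    # hidden_size), Bi/Bh (4 * hidden_size,), P (proj_size, hidden_size) and
+    # p_i/p_f/p_o (hidden_size,); gates are split in i, f, c, o order.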
+ # Wi and Wh shoud exist the others can be None + + outputs_list = [] + for x_t in input_seqs if not backwards else reversed(input_seqs): + # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) + step = concatenate([x_t, H_t], axis=1) + W = concatenate([Wi, Wh], axis=1) + # Instead of nn.dense(x_t, weights[0]) + nn.dense(H_t, weights[1]) we have nn.dense(step, W) + # gates shape = (batch, 4 * hidden_size) + gates = nn.dense(step, W) + # Add biases + if Bi is not None: + gates += Bi + if Bh is not None: + gates += Bh + i, f, c, o = split(gates, 4, axis=-1) # (batch, hidden_size) + + if p_i is not None and p_f is not None: + i = f_act(i + p_i * C_t) + f = f_act(f + p_f * C_t) + else: + i = f_act(i) + f = f_act(f) + + c = g_act(c) + C_t = f * C_t + i * c + if p_o is not None: + o = f_act(o + p_o * C_t) + else: + o = f_act(o) + + H_t = o * h_act(C_t) + + if P is not None: + H_t = nn.dense(H_t, P) + + outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] + + return outputs_list, H_t, C_t From 8db8acfae1c9422d4f518c343668f201896139d5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 30 Jul 2021 13:29:51 +0300 Subject: [PATCH 08/21] lint fixes --- python/tvm/contrib/target/onnx.py | 2 +- python/tvm/relay/frontend/onnx.py | 13 ++-- python/tvm/relay/frontend/pytorch.py | 96 +++++++++++++--------------- python/tvm/relay/op/layers.py | 62 +++++++++--------- 4 files changed, 84 insertions(+), 89 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index b839af669fe6..6f8aab23cde1 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -655,7 +655,7 @@ def convert_attributes(cls, attrs): class Cast(OpConverter): - """ Operator converter for Cast.""" + """Operator converter for Cast.""" @classmethod def convert_attributes(cls, attrs): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f29308459a67..88310e6bd173 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -23,7 +23,6 @@ import numpy as np import tvm from tvm.ir import IRModule -from tvm.relay.op.tensor import concatenate from tvm.topi.utils import get_const_tuple from ... 
import nd as _nd @@ -2278,15 +2277,15 @@ def _impl_v7(cls, inputs, attr, params): for i in range(num_directions): weights_dict = {} - weights_dict["H_t"] = _op.squeeze(H_ts[i], axis=[0]) - weights_dict["C_t"] = _op.squeeze(C_ts[i], axis=[0]) + weights_dict["ht"] = _op.squeeze(H_ts[i], axis=[0]) + weights_dict["ct"] = _op.squeeze(C_ts[i], axis=[0]) - weights_dict["Wi"] = _op.squeeze(Ws[i], axis=[0]) - weights_dict["Wh"] = _op.squeeze(Rs[i], axis=[0]) + weights_dict["wi"] = _op.squeeze(Ws[i], axis=[0]) + weights_dict["wh"] = _op.squeeze(Rs[i], axis=[0]) if Bp is not None: Bi, Bh = _op.split(Bs[i], 2, -1) - weights_dict["Bi"] = Bi - weights_dict["Bh"] = Bh + weights_dict["bi"] = Bi + weights_dict["bh"] = Bh if Pp is not None: weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0]) weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0]) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 036588c96645..b463937cf3e8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2355,21 +2355,17 @@ def bidir_lstm_cell( return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t) - def lstm_layers( - self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0 - ): + def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0): layers_num = len(layer_weights_dicts) # split input sequence to samples set input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)] output_hiddens = [] for i in range(layers_num): - weights_dicts=layer_weights_dicts[i] + weights_dicts = layer_weights_dicts[i] # input_seqs shape = [seq_num, (batch, feature_size)] or # [seq_num, (batch, 2*feature_size)] for bidirectional if bidirectional: - input_seqs, H_t, C_t = self.bidir_lstm_cell( - input_seqs, weights_dicts - ) + input_seqs, H_t, C_t = self.bidir_lstm_cell(input_seqs, weights_dicts) else: input_seqs, H_t, C_t = _op.lstm_cell(input_seqs, **weights_dicts[0]) @@ -2504,81 +2500,81 @@ def lstm(self, inputs, input_types): layers_c = self.unbind((c_0, 0), X_dtype) layer_weights_dicts = [] - k=0 # layer counter + k = 0 # layer counter if has_biases: if bidirectional: rsd = len(_weights) % (2 * weights_num) assert rsd == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), 2 * weights_num): fw_weights_dict = {} - fw_weights_dict["H_t"]=layers_h[2*k] - fw_weights_dict["C_t"]=layers_c[2*k] - fw_weights_dict["Wi"]=_weights[i] - fw_weights_dict["Wh"]=_weights[i+1] - fw_weights_dict["Bi"]=_weights[i+2] - fw_weights_dict["Bh"]=_weights[i+3] + fw_weights_dict["ht"] = layers_h[2 * k] + fw_weights_dict["ct"] = layers_c[2 * k] + fw_weights_dict["wi"] = _weights[i] + fw_weights_dict["wh"] = _weights[i + 1] + fw_weights_dict["bi"] = _weights[i + 2] + fw_weights_dict["bh"] = _weights[i + 3] if has_proj: - fw_weights_dict["P"]=_weights[i+4] + fw_weights_dict["P"] = _weights[i + 4] rev_weights_dict = {} - j=i+weights_num - rev_weights_dict["H_t"]=layers_h[2*k+1] - rev_weights_dict["C_t"]=layers_c[2*k+1] - rev_weights_dict["Wi"]=_weights[j] - rev_weights_dict["Wh"]=_weights[j+1] - rev_weights_dict["Bi"]=_weights[j+2] - rev_weights_dict["Bh"]=_weights[j+3] + j = i + weights_num + rev_weights_dict["ht"] = layers_h[2 * k + 1] + rev_weights_dict["ct"] = layers_c[2 * k + 1] + rev_weights_dict["wi"] = _weights[j] + rev_weights_dict["wh"] = _weights[j + 1] + rev_weights_dict["bi"] = _weights[j + 2] + rev_weights_dict["bh"] = _weights[j + 3] if has_proj: - 
rev_weights_dict["P"]=_weights[j+4] + rev_weights_dict["p"] = _weights[j + 4] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) - k+=1 + k += 1 else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): fw_weights_dict = {} - fw_weights_dict["H_t"]=layers_h[k] - fw_weights_dict["C_t"]=layers_c[k] - fw_weights_dict["Wi"]=_weights[i] - fw_weights_dict["Wh"]=_weights[i+1] - fw_weights_dict["Bi"]=_weights[i+2] - fw_weights_dict["Bh"]=_weights[i+3] + fw_weights_dict["ht"] = layers_h[k] + fw_weights_dict["ct"] = layers_c[k] + fw_weights_dict["wi"] = _weights[i] + fw_weights_dict["wh"] = _weights[i + 1] + fw_weights_dict["bi"] = _weights[i + 2] + fw_weights_dict["bh"] = _weights[i + 3] if has_proj: - fw_weights_dict["P"]=_weights[i+4] + fw_weights_dict["p"] = _weights[i + 4] layer_weights_dicts.append([fw_weights_dict]) - k+=1 + k += 1 else: if bidirectional: rsd = len(_weights) % (2 * weights_num) assert rsd == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), 2 * weights_num): fw_weights_dict = {} - fw_weights_dict["H_t"]=layers_h[2*k] - fw_weights_dict["C_t"]=layers_c[2*k] - fw_weights_dict["Wi"]=_weights[i] - fw_weights_dict["Wh"]=_weights[i+1] + fw_weights_dict["ht"] = layers_h[2 * k] + fw_weights_dict["ct"] = layers_c[2 * k] + fw_weights_dict["wi"] = _weights[i] + fw_weights_dict["wh"] = _weights[i + 1] if has_proj: - fw_weights_dict["P"]=_weights[i+2] + fw_weights_dict["p"] = _weights[i + 2] rev_weights_dict = {} - j=i+weights_num - rev_weights_dict["H_t"]=layers_h[2*k+1] - rev_weights_dict["C_t"]=layers_c[2*k+1] - rev_weights_dict["Wi"]=_weights[j] - rev_weights_dict["Wh"]=_weights[j+1] + j = i + weights_num + rev_weights_dict["ht"] = layers_h[2 * k + 1] + rev_weights_dict["ct"] = layers_c[2 * k + 1] + rev_weights_dict["wi"] = _weights[j] + rev_weights_dict["wh"] = _weights[j + 1] if has_proj: - rev_weights_dict["P"]=_weights[j+2] + rev_weights_dict["P"] = _weights[j + 2] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) - k+=1 + k += 1 else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): fw_weights_dict = {} - fw_weights_dict["H_t"]=layers_h[k] - fw_weights_dict["C_t"]=layers_c[k] - fw_weights_dict["Wi"]=_weights[i] - fw_weights_dict["Wh"]=_weights[i+1] + fw_weights_dict["ht"] = layers_h[k] + fw_weights_dict["ct"] = layers_c[k] + fw_weights_dict["wi"] = _weights[i] + fw_weights_dict["wh"] = _weights[i + 1] if has_proj: - fw_weights_dict["P"]=_weights[i+2] + fw_weights_dict["p"] = _weights[i + 2] layer_weights_dicts.append([fw_weights_dict]) - k+=1 + k += 1 assert ( len(layer_weights_dicts) == num_layers and k == num_layers ), "For stacked LSTM number of weights sets should be the same as number of layers!" diff --git a/python/tvm/relay/op/layers.py b/python/tvm/relay/op/layers.py index b386ede4ba48..01dcda10ad4a 100644 --- a/python/tvm/relay/op/layers.py +++ b/python/tvm/relay/op/layers.py @@ -24,13 +24,13 @@ def lstm_cell( input_seqs, - H_t, - C_t, - Wi, - Wh, - Bi=None, - Bh=None, - P=None, + ht, + ct, + wi, + wh, + bi=None, + bh=None, + p=None, p_i=None, p_f=None, p_o=None, @@ -40,43 +40,43 @@ def lstm_cell( backwards=False, ): # Input hidden state shape = (batch, hidden_size) - # Wi, Wh, Bi, Bh, proj matrix P, peephole matrices: p_i, p_f, p_o are expected. 
- # Wi and Wh shoud exist the others can be None + # wi, wh, bi, bh, proj matrix (p), peephole matrices: p_i, p_f, p_o are expected. + # wi and wh shoud exist the others can be None outputs_list = [] for x_t in input_seqs if not backwards else reversed(input_seqs): # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) - step = concatenate([x_t, H_t], axis=1) - W = concatenate([Wi, Wh], axis=1) - # Instead of nn.dense(x_t, weights[0]) + nn.dense(H_t, weights[1]) we have nn.dense(step, W) + step = concatenate([x_t, ht], axis=1) + w = concatenate([wi, wh], axis=1) + # Instead of nn.dense(x_t, weights[0]) + nn.dense(ht, weights[1]) we have nn.dense(step, W) # gates shape = (batch, 4 * hidden_size) - gates = nn.dense(step, W) + gates = nn.dense(step, w) # Add biases - if Bi is not None: - gates += Bi - if Bh is not None: - gates += Bh - i, f, c, o = split(gates, 4, axis=-1) # (batch, hidden_size) + if bi is not None: + gates += bi + if bh is not None: + gates += bh + ig, fg, cg, og = split(gates, 4, axis=-1) # (batch, hidden_size) if p_i is not None and p_f is not None: - i = f_act(i + p_i * C_t) - f = f_act(f + p_f * C_t) + ig = f_act(ig + p_i * ct) + fg = f_act(fg + p_f * ct) else: - i = f_act(i) - f = f_act(f) + ig = f_act(ig) + fg = f_act(fg) - c = g_act(c) - C_t = f * C_t + i * c + cg = g_act(cg) + ct = fg * ct + ig * cg if p_o is not None: - o = f_act(o + p_o * C_t) + og = f_act(og + p_o * ct) else: - o = f_act(o) + og = f_act(og) - H_t = o * h_act(C_t) + ht = og * h_act(ct) - if P is not None: - H_t = nn.dense(H_t, P) + if p is not None: + ht = nn.dense(ht, p) - outputs_list.append(H_t) # [seq_num, (batch, hidden_size)] + outputs_list.append(ht) # [seq_num, (batch, hidden_size)] - return outputs_list, H_t, C_t + return outputs_list, ht, ct From 0e99c90c6d613e991976cb008456d2c8ce31bd36 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 30 Jul 2021 21:13:00 +0300 Subject: [PATCH 09/21] Weights permutation for LSTM layer in onnx frontend --- python/tvm/relay/frontend/onnx.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 88310e6bd173..75d546e44351 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2280,12 +2280,17 @@ def _impl_v7(cls, inputs, attr, params): weights_dict["ht"] = _op.squeeze(H_ts[i], axis=[0]) weights_dict["ct"] = _op.squeeze(C_ts[i], axis=[0]) - weights_dict["wi"] = _op.squeeze(Ws[i], axis=[0]) - weights_dict["wh"] = _op.squeeze(Rs[i], axis=[0]) + # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o + mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4) + weights_dict["wi"] = _op.concatenate([mati, matf, matc, mato], axis=0) + mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4) + weights_dict["wh"] = _op.concatenate([mati, matf, matc, mato], axis=0) if Bp is not None: Bi, Bh = _op.split(Bs[i], 2, -1) - weights_dict["bi"] = Bi - weights_dict["bh"] = Bh + mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4) + weights_dict["bi"] = _op.concatenate([mati, matf, matc, mato], axis=0) + mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4) + weights_dict["bh"] = _op.concatenate([mati, matf, matc, mato], axis=0) if Pp is not None: weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0]) weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0]) From f5d31d3d75820fe5257bd14f4876d644fa81be7b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: 
Fri, 30 Jul 2021 21:46:04 +0300 Subject: [PATCH 10/21] LSTM cell description was added --- python/tvm/relay/op/layers.py | 38 ++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/layers.py b/python/tvm/relay/op/layers.py index 01dcda10ad4a..c4b41a06ff12 100644 --- a/python/tvm/relay/op/layers.py +++ b/python/tvm/relay/op/layers.py @@ -39,9 +39,41 @@ def lstm_cell( h_act=tanh, backwards=False, ): - # Input hidden state shape = (batch, hidden_size) - # wi, wh, bi, bh, proj matrix (p), peephole matrices: p_i, p_f, p_o are expected. - # wi and wh shoud exist the others can be None + """ + Common implementation of LSTM cell for all frontends of TVM + TODO (vvchernov): currently it is used by onnx and pytorch. + + Parameters + ---------- + input_seqs : List[relay.Expr] + The sequence of input tensors + Input tensor should be 2d while issue #8412 is not resolved + Shape = (batch, feature_size) + ht : relay.Expr + Hidden state. shape = (batch, hidden_size) + ct : relay.Expr + Cell state. shape = (batch, hidden_size) + wi, wh : relay.Expr + weight matrices. wi shape = (4 * hidden_size, feature_size) + wh shape = (4 * hidden_size, hidden_size or proj_size) + NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates. + The order is important for correct LSTM calculation! + bi, bh : relay.Expr + bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size) + p : relay.Expr + projection matrix. shape = (proj_size, hidden_size) + p_i, p_f, p_o : relay.Expr + peephole LSTM matrices. shape = (batch, hidden_size) + f_act, g_act, h_act : relay.op + activation funtions + backwards : bool + Flag for reverse pass of LSTM + + Returns + ------- + result : List[relay.Expr], relay.Expr, relay.Expr + The sequence of computed result, final hidden and cell state + """ outputs_list = [] for x_t in input_seqs if not backwards else reversed(input_seqs): From 0c5044b4438d7c3f4dd355f0fa8b7a70067360a6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 31 Jul 2021 16:52:42 +0300 Subject: [PATCH 11/21] arguments and values were renamed. descriptions of some methods were added --- python/tvm/relay/frontend/onnx.py | 35 +++++++++---- python/tvm/relay/frontend/pytorch.py | 78 +++++++++++++++------------- python/tvm/relay/op/layers.py | 69 ++++++++++++------------ 3 files changed, 103 insertions(+), 79 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 75d546e44351..d7d1e4813b5c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2141,11 +2141,25 @@ def _activation_needs_beta(cls, activation): class LSTM(RNN): """Operator converter for LSTM""" - # TODO (vvchernov): unbind was gotten from pytorch.py and modified. - # It looks like torch.unbind - # It needs such operation on relay side to avoid excess manipulation like squeeze @classmethod def unbind(cls, data, axis=0): + """ + Unbind was gotten from pytorch.py and modified. The operation removes a tensor dimension + and returns a tuple of all slices along a given dimension, already without it. + TODO (vvchernov): It needs such operation on relay side to reduce time consumption + on squeeze operation. + + Parameters + ---------- + data : relay.Expr + Input tensor + axis : int + Axis along which tensor is splited. Tensors in the list has not this axis. 
+        Returns
+        -------
+        result : List[relay.Expr]
+            The sequence of computed tensors
+        """
         shape = infer_shape(data)
         selections = shape[axis]
         res_split = _op.split(data, selections, axis)
@@ -2164,6 +2178,9 @@ def bidir_lstm_cell(
         g_act,
         h_act,
     ):
+        """
+        Bidirectional LSTM cell
+        """
         seq_len = len(input_seqs)
         forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell(
             input_seqs,
@@ -2277,20 +2294,20 @@ def _impl_v7(cls, inputs, attr, params):
         for i in range(num_directions):
             weights_dict = {}
-            weights_dict["ht"] = _op.squeeze(H_ts[i], axis=[0])
-            weights_dict["ct"] = _op.squeeze(C_ts[i], axis=[0])
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["cell_state"] = _op.squeeze(C_ts[i], axis=[0])
             # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o
             mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4)
-            weights_dict["wi"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+            weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
             mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4)
-            weights_dict["wh"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+            weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
             if Bp is not None:
                 Bi, Bh = _op.split(Bs[i], 2, -1)
                 mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4)
-                weights_dict["bi"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+                weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
                 mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4)
-                weights_dict["bh"] = _op.concatenate([mati, matf, matc, mato], axis=0)
+                weights_dict["b_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
             if Pp is not None:
                 weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0])
                 weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0])
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b463937cf3e8..be62d920f9ff 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2335,6 +2335,9 @@ def bidir_lstm_cell(
         input_seqs,
         weights_dicts,
     ):
+        """
+        Bidirectional LSTM cell
+        """
         seq_len = len(input_seqs)
         forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell(
             input_seqs,
@@ -2356,6 +2359,9 @@ def bidir_lstm_cell(
         return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t)

     def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0):
+        """
+        Method iterates over layers for stacked LSTM
+        """
         layers_num = len(layer_weights_dicts)
         # split input sequence to samples set
         input_seqs = self.unbind((input_data, 0), dtype)  # [seq_num, (batch, feature_size)]
         output_hiddens = []
         for i in range(layers_num):
             weights_dicts = layer_weights_dicts[i]
@@ -2507,38 +2513,38 @@ def lstm(self, inputs, input_types):
                 assert rsd == 0, "got an incorrect number of LSTM weights"
                 for i in range(0, len(_weights), 2 * weights_num):
                     fw_weights_dict = {}
-                    fw_weights_dict["ht"] = layers_h[2 * k]
-                    fw_weights_dict["ct"] = layers_c[2 * k]
-                    fw_weights_dict["wi"] = _weights[i]
-                    fw_weights_dict["wh"] = _weights[i + 1]
-                    fw_weights_dict["bi"] = _weights[i + 2]
-                    fw_weights_dict["bh"] = _weights[i + 3]
+                    fw_weights_dict["hidden_state"] = layers_h[2 * k]
+                    fw_weights_dict["cell_state"] = layers_c[2 * k]
+                    fw_weights_dict["w_inp"] = _weights[i]
+                    fw_weights_dict["w_hid"] = _weights[i + 1]
+                    fw_weights_dict["b_inp"] = _weights[i + 2]
+                    fw_weights_dict["b_hid"] = _weights[i + 3]
                     if has_proj:
-                        fw_weights_dict["P"] = _weights[i + 4]
+                        fw_weights_dict["proj"] = _weights[i + 4]
                     rev_weights_dict = {}
                     j = i + weights_num
-                    rev_weights_dict["ht"] = layers_h[2 * k +
1] - rev_weights_dict["ct"] = layers_c[2 * k + 1] - rev_weights_dict["wi"] = _weights[j] - rev_weights_dict["wh"] = _weights[j + 1] - rev_weights_dict["bi"] = _weights[j + 2] - rev_weights_dict["bh"] = _weights[j + 3] + rev_weights_dict["hidden_state"] = layers_h[2 * k + 1] + rev_weights_dict["cell_state"] = layers_c[2 * k + 1] + rev_weights_dict["w_inp"] = _weights[j] + rev_weights_dict["w_hid"] = _weights[j + 1] + rev_weights_dict["b_inp"] = _weights[j + 2] + rev_weights_dict["b_hid"] = _weights[j + 3] if has_proj: - rev_weights_dict["p"] = _weights[j + 4] + rev_weights_dict["proj"] = _weights[j + 4] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) k += 1 else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): fw_weights_dict = {} - fw_weights_dict["ht"] = layers_h[k] - fw_weights_dict["ct"] = layers_c[k] - fw_weights_dict["wi"] = _weights[i] - fw_weights_dict["wh"] = _weights[i + 1] - fw_weights_dict["bi"] = _weights[i + 2] - fw_weights_dict["bh"] = _weights[i + 3] + fw_weights_dict["hidden_state"] = layers_h[k] + fw_weights_dict["cell_state"] = layers_c[k] + fw_weights_dict["w_inp"] = _weights[i] + fw_weights_dict["w_hid"] = _weights[i + 1] + fw_weights_dict["b_inp"] = _weights[i + 2] + fw_weights_dict["b_hid"] = _weights[i + 3] if has_proj: - fw_weights_dict["p"] = _weights[i + 4] + fw_weights_dict["proj"] = _weights[i + 4] layer_weights_dicts.append([fw_weights_dict]) k += 1 else: @@ -2547,32 +2553,32 @@ def lstm(self, inputs, input_types): assert rsd == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), 2 * weights_num): fw_weights_dict = {} - fw_weights_dict["ht"] = layers_h[2 * k] - fw_weights_dict["ct"] = layers_c[2 * k] - fw_weights_dict["wi"] = _weights[i] - fw_weights_dict["wh"] = _weights[i + 1] + fw_weights_dict["hidden_state"] = layers_h[2 * k] + fw_weights_dict["cell_state"] = layers_c[2 * k] + fw_weights_dict["w_inp"] = _weights[i] + fw_weights_dict["w_hid"] = _weights[i + 1] if has_proj: - fw_weights_dict["p"] = _weights[i + 2] + fw_weights_dict["proj"] = _weights[i + 2] rev_weights_dict = {} j = i + weights_num - rev_weights_dict["ht"] = layers_h[2 * k + 1] - rev_weights_dict["ct"] = layers_c[2 * k + 1] - rev_weights_dict["wi"] = _weights[j] - rev_weights_dict["wh"] = _weights[j + 1] + rev_weights_dict["hidden_state"] = layers_h[2 * k + 1] + rev_weights_dict["cell_state"] = layers_c[2 * k + 1] + rev_weights_dict["w_inp"] = _weights[j] + rev_weights_dict["w_hid"] = _weights[j + 1] if has_proj: - rev_weights_dict["P"] = _weights[j + 2] + rev_weights_dict["proj"] = _weights[j + 2] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) k += 1 else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): fw_weights_dict = {} - fw_weights_dict["ht"] = layers_h[k] - fw_weights_dict["ct"] = layers_c[k] - fw_weights_dict["wi"] = _weights[i] - fw_weights_dict["wh"] = _weights[i + 1] + fw_weights_dict["hidden_state"] = layers_h[k] + fw_weights_dict["cell_state"] = layers_c[k] + fw_weights_dict["w_inp"] = _weights[i] + fw_weights_dict["w_hid"] = _weights[i + 1] if has_proj: - fw_weights_dict["p"] = _weights[i + 2] + fw_weights_dict["proj"] = _weights[i + 2] layer_weights_dicts.append([fw_weights_dict]) k += 1 assert ( diff --git a/python/tvm/relay/op/layers.py b/python/tvm/relay/op/layers.py index c4b41a06ff12..9acf223f8240 100644 --- a/python/tvm/relay/op/layers.py +++ 
b/python/tvm/relay/op/layers.py
@@ -24,13 +24,13 @@
 def lstm_cell(
     input_seqs,
-    ht,
-    ct,
-    wi,
-    wh,
-    bi=None,
-    bh=None,
-    p=None,
+    hidden_state,
+    cell_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    proj=None,
     p_i=None,
     p_f=None,
     p_o=None,
@@ -49,18 +49,18 @@ def lstm_cell(
         The sequence of input tensors
         Input tensor should be 2d while issue #8412 is not resolved
         Shape = (batch, feature_size)
-    ht : relay.Expr
+    hidden_state : relay.Expr
         Hidden state. shape = (batch, hidden_size)
-    ct : relay.Expr
+    cell_state : relay.Expr
         Cell state. shape = (batch, hidden_size)
-    wi, wh : relay.Expr
+    w_inp, w_hid : relay.Expr
         weight matrices. wi shape = (4 * hidden_size, feature_size)
         wh shape = (4 * hidden_size, hidden_size or proj_size)
         NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates.
         The order is important for correct LSTM calculation!
-    bi, bh : relay.Expr
+    b_inp, b_hid : relay.Expr
         bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size)
-    p : relay.Expr
+    proj : relay.Expr
         projection matrix. shape = (proj_size, hidden_size)
     p_i, p_f, p_o : relay.Expr
         peephole LSTM matrices. shape = (batch, hidden_size)
@@ -78,37 +78,38 @@ def lstm_cell(
     outputs_list = []
     for x_t in input_seqs if not backwards else reversed(input_seqs):
         # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size)
-        step = concatenate([x_t, ht], axis=1)
-        w = concatenate([wi, wh], axis=1)
-        # Instead of nn.dense(x_t, weights[0]) + nn.dense(ht, weights[1]) we have nn.dense(step, W)
+        step = concatenate([x_t, hidden_state], axis=1)
+        cat_w = concatenate([w_inp, w_hid], axis=1)
+        # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
+        # the nn.dense(step, cat_w) is used
         # gates shape = (batch, 4 * hidden_size)
-        gates = nn.dense(step, w)
+        gates = nn.dense(step, cat_w)
         # Add biases
-        if bi is not None:
-            gates += bi
-        if bh is not None:
-            gates += bh
-        ig, fg, cg, og = split(gates, 4, axis=-1)  # (batch, hidden_size)
+        if b_inp is not None:
+            gates += b_inp
+        if b_hid is not None:
+            gates += b_hid
+        inp_gate, fgt_gate, cell_gate, otp_gate = split(gates, 4, axis=-1)  # (batch, hidden_size)

         if p_i is not None and p_f is not None:
-            ig = f_act(ig + p_i * ct)
-            fg = f_act(fg + p_f * ct)
+            inp_gate = f_act(inp_gate + p_i * cell_state)
+            fgt_gate = f_act(fgt_gate + p_f * cell_state)
         else:
-            ig = f_act(ig)
-            fg = f_act(fg)
+            inp_gate = f_act(inp_gate)
+            fgt_gate = f_act(fgt_gate)

-        cg = g_act(cg)
-        ct = fg * ct + ig * cg
+        cell_gate = g_act(cell_gate)
+        cell_state = fgt_gate * cell_state + inp_gate * cell_gate
         if p_o is not None:
-            og = f_act(og + p_o * ct)
+            otp_gate = f_act(otp_gate + p_o * cell_state)
         else:
-            og = f_act(og)
+            otp_gate = f_act(otp_gate)

-        ht = og * h_act(ct)
+        hidden_state = otp_gate * h_act(cell_state)

-        if p is not None:
-            ht = nn.dense(ht, p)
+        if proj is not None:
+            hidden_state = nn.dense(hidden_state, proj)

-        outputs_list.append(ht)  # [seq_num, (batch, hidden_size)]
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]

-    return outputs_list, ht, ct
+    return outputs_list, hidden_state, cell_state

From 6877742f5eb85dc93846d966f1176dd0f04a4629 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 2 Aug 2021 13:42:06 +0300
Subject: [PATCH 12/21] LSTM output shape and activations input format were fixed in onnx frontend

---
 python/tvm/relay/frontend/onnx.py | 35 +++++++++++++------------------
 1 file changed, 14 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py
b/python/tvm/relay/frontend/onnx.py index d7d1e4813b5c..fd73834f6853 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2174,9 +2174,7 @@ def bidir_lstm_cell( cls, input_seqs, weight_dicts, - f_act, - g_act, - h_act, + acts, ): """ Bidirectional LSTM cell @@ -2185,17 +2183,17 @@ def bidir_lstm_cell( forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell( input_seqs, **weight_dicts[0], - f_act=f_act, - g_act=g_act, - h_act=h_act, + f_act=acts[0], + g_act=acts[1], + h_act=acts[2], ) reverse_outputs, rev_H_t, rev_C_t = _op.lstm_cell( input_seqs, **weight_dicts[1], - f_act=f_act, - g_act=g_act, - h_act=h_act, + f_act=acts[3], + g_act=acts[4], + h_act=acts[5], backwards=True, ) @@ -2267,11 +2265,8 @@ def _impl_v7(cls, inputs, attr, params): beta = betas[beta_loc] beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) - f_act, g_act, h_act = acts else: - f_act = _op.sigmoid - g_act = _op.tanh - h_act = _op.tanh + acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved X_steps = LSTM.unbind(X, axis=0) @@ -2318,24 +2313,22 @@ def _impl_v7(cls, inputs, attr, params): output, H, C = LSTM.bidir_lstm_cell( input_seqs=X_steps, weight_dicts=weights_dicts, - f_act=f_act, - g_act=g_act, - h_act=h_act, + acts=acts, ) else: # outputs shape = [seqs_num, (batch_size, hidden_size)] outputs, H, C = _op.lstm_cell( input_seqs=X_steps, **weights_dicts[0], - f_act=f_act, - g_act=g_act, - h_act=h_act, + f_act=acts[0], + g_act=acts[1], + h_act=acts[2], ) # output shape = (seqs_num, num_directions, batch_size, hidden_size) output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) - H = _op.expand_dims(H, axis=1) - C = _op.expand_dims(C, axis=1) + H = _op.expand_dims(H, axis=0) + C = _op.expand_dims(C, axis=0) return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) From 845ebcfbd6f0fd57e672c4dcceb36376d301a898 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 2 Aug 2021 22:12:25 +0300 Subject: [PATCH 13/21] empty. tvm-ci test From d3afbc6054b8a4e33c3eb02279bc8d3419cb41d9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 3 Aug 2021 15:19:59 +0300 Subject: [PATCH 14/21] unbind method was transferred from onnx frontend to common.py --- python/tvm/relay/frontend/common.py | 28 ++++++++++++++++++++++++++ python/tvm/relay/frontend/onnx.py | 31 ++--------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 9c53b59f9998..debd7586952c 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -624,3 +624,31 @@ def to_int_list(np_array): cause problems in relay/TOPI. """ return [int(x) for x in np_array] + + +def unbind(data, axis=0): + """ + Unbind was gotten from pytorch.py and modified. The operation removes a tensor dimension + and returns a tuple of all slices along a given dimension, already without it. + TODO (vvchernov): It needs such operation on relay side to reduce time consumption + on squeeze operation. + + Parameters + ---------- + data : relay.Expr + Input tensor + axis : int + Axis along which tensor is splited. Tensors in the list do not have this axis. 
+ Returns + ------- + result : List[relay.Expr] + The sequence of computed tensors + """ + shape = infer_shape(data) + selections = shape[axis] + res_split = _op.split(data, selections, axis) + ret = [] + for i in range(selections): + ret.append(_op.squeeze(res_split[i], axis=[axis])) + ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) + return ret diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fd73834f6853..8f863aacab9a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -46,6 +46,7 @@ infer_type, infer_value, new_var, + unbind, ) __all__ = ["from_onnx"] @@ -2141,34 +2142,6 @@ def _activation_needs_beta(cls, activation): class LSTM(RNN): """Operator converter for LSTM""" - @classmethod - def unbind(cls, data, axis=0): - """ - Unbind was gotten from pytorch.py and modified. The operation removes a tensor dimension - and returns a tuple of all slices along a given dimension, already without it. - TODO (vvchernov): It needs such operation on relay side to reduce time consumption - on squeeze operation. - - Parameters - ---------- - data : relay.Expr - Input tensor - axis : int - Axis along which tensor is splited. Tensors in the list has not this axis. - Returns - ------- - result : List[relay.Expr] - The sequence of computed tensors - """ - shape = infer_shape(data) - selections = shape[axis] - res_split = _op.split(data, selections, axis) - ret = [] - for i in range(selections): - ret.append(_op.squeeze(res_split[i], axis=[axis])) - ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) - return ret - @classmethod def bidir_lstm_cell( cls, @@ -2269,7 +2242,7 @@ def _impl_v7(cls, inputs, attr, params): acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved - X_steps = LSTM.unbind(X, axis=0) + X_steps = unbind(X, axis=0) H_ts = _op.split(Hp_0, num_directions) C_ts = _op.split(Cp_0, num_directions) From f81933300d3cc61d04e31146cb865b66650ede77 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 3 Aug 2021 15:39:18 +0300 Subject: [PATCH 15/21] unbind method was transferred from pytorch frontend to common.py --- python/tvm/relay/frontend/common.py | 4 ++++ python/tvm/relay/frontend/pytorch.py | 30 ++++++---------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index debd7586952c..0b75a48ca66c 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -645,6 +645,10 @@ def unbind(data, axis=0): The sequence of computed tensors """ shape = infer_shape(data) + if axis >= len(shape): + msg = "Please check input dim, it shouldn't be greater than or equal to rank." + raise AttributeError(msg) + selections = shape[axis] res_split = _op.split(data, selections, axis) ret = [] diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index be62d920f9ff..3417d9c5247e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -39,7 +39,7 @@ from ..prelude import Prelude, StaticTensorArrayOps from ..ty import Any, TensorType, TupleType from . 
import qnn_torch -from .common import AttrCvt, get_relay_op +from .common import AttrCvt, get_relay_op, unbind from .common import infer_value as _infer_value from .common import infer_shape as _infer_shape from .common import infer_value_simulated as _infer_value_simulated @@ -2092,24 +2092,6 @@ def deform_conv2d(self, inputs, input_types): kernel_size, ) - def unbind(self, inputs, input_types): - data = inputs[0] - dim = int(inputs[1]) - ishapes = self.infer_shape(data) - if dim >= len(ishapes): - msg = "Please check input dim, it shouldn't be greater than or equal to rank." - raise AttributeError(msg) - - selections = ishapes[dim] - res_split = _op.split(data, selections, dim) - # squeeze each split piece to get same shape as aten::unbind - # TODO (yongwww): add new op to avoid the squeeze overhead - ret = [] - for i in range(selections): - ret.append(_op.transform.squeeze(res_split[i], axis=[dim])) - ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) - return ret - def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False input_shape = self.infer_shape(inputs[0], self.prelude.mod) @@ -2135,7 +2117,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) if is_numpy_style or (len(inputs) > 1 and inputs[1]): - return self.unbind([ret, 1], None) + return unbind(ret, 1) return ret def nonzero_numpy(self, inputs, input_types): @@ -2364,7 +2346,7 @@ def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dro """ layers_num = len(layer_weights_dicts) # split input sequence to samples set - input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)] + input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)] output_hiddens = [] for i in range(layers_num): weights_dicts = layer_weights_dicts[i] @@ -2497,13 +2479,13 @@ def lstm(self, inputs, input_types): for i in range(hidden_layers_num): layers_h.append(h_0) else: - layers_h = self.unbind((h_0, 0), X_dtype) + layers_h = unbind(h_0, 0) if c_0 is None: c_0 = _op.zeros((batch_size, hidden_size), X_dtype) for i in range(hidden_layers_num): layers_c.append(c_0) else: - layers_c = self.unbind((c_0, 0), X_dtype) + layers_c = unbind(c_0, 0) layer_weights_dicts = [] k = 0 # layer counter @@ -2792,7 +2774,7 @@ def create_convert_map(self): "aten::logsumexp": self.logsumexp, "torchvision::roi_align": self.roi_align, "torchvision::deform_conv2d": self.deform_conv2d, - "aten::unbind": self.unbind, + "aten::unbind": unbind, "aten::__and__": self.logical_and, "aten::logical_and": self.logical_and, "aten::_shape_as_tensor": self.shape_as_tensor, From 4fc6505705922795a2dc3c0b120fc3165af4aa1e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 3 Aug 2021 16:04:49 +0300 Subject: [PATCH 16/21] lstm cell was transferred from op/layers.py to frontend/common.py --- python/tvm/relay/frontend/common.py | 94 ++++++++++++++++++++++ python/tvm/relay/frontend/onnx.py | 7 +- python/tvm/relay/frontend/pytorch.py | 8 +- python/tvm/relay/op/__init__.py | 1 - python/tvm/relay/op/layers.py | 115 --------------------------- 5 files changed, 102 insertions(+), 123 deletions(-) delete mode 100644 python/tvm/relay/op/layers.py diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 0b75a48ca66c..85005b8302b4 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -656,3 +656,97 @@ def unbind(data, axis=0): ret.append(_op.squeeze(res_split[i], axis=[axis])) ret = 
_expr.TupleWrapper(_expr.Tuple(ret), selections)
     return ret
+
+
+def lstm_cell(
+    input_seqs,
+    hidden_state,
+    cell_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    proj=None,
+    p_i=None,
+    p_f=None,
+    p_o=None,
+    f_act=_op.sigmoid,
+    g_act=_op.tanh,
+    h_act=_op.tanh,
+    backwards=False,
+):
+    """
+    Common implementation of LSTM cell for all frontends of TVM
+    TODO (vvchernov): currently it is used by onnx and pytorch. Extend for other frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensor should be 2d while issue #8412 is not resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch, hidden_size)
+    cell_state : relay.Expr
+        Cell state. shape = (batch, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. w_inp shape = (4 * hidden_size, feature_size)
+        w_hid shape = (4 * hidden_size, hidden_size or proj_size)
+        NOTE: w_inp = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates.
+        The order is important for correct LSTM calculation!
+    b_inp, b_hid : relay.Expr
+        bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size)
+    proj : relay.Expr
+        projection matrix. shape = (proj_size, hidden_size)
+    p_i, p_f, p_o : relay.Expr
+        peephole LSTM matrices. shape = (batch, hidden_size)
+    f_act, g_act, h_act : relay.op
+        activation functions
+    backwards : bool
+        Flag for reverse pass of LSTM
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr, relay.Expr
+        The sequence of computed result, final hidden and cell state
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size)
+        step = _op.concatenate([x_t, hidden_state], axis=1)
+        cat_w = _op.concatenate([w_inp, w_hid], axis=1)
+        # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
+        # the nn.dense(step, cat_w) is used
+        # gates shape = (batch, 4 * hidden_size)
+        gates = _op.nn.dense(step, cat_w)
+        # Add biases
+        if b_inp is not None:
+            gates += b_inp
+        if b_hid is not None:
+            gates += b_hid
+        # any gate shape = (batch, hidden_size)
+        inp_gate, fgt_gate, cell_gate, otp_gate = _op.split(gates, 4, axis=-1)
+
+        if p_i is not None and p_f is not None:
+            inp_gate = f_act(inp_gate + p_i * cell_state)
+            fgt_gate = f_act(fgt_gate + p_f * cell_state)
+        else:
+            inp_gate = f_act(inp_gate)
+            fgt_gate = f_act(fgt_gate)
+
+        cell_gate = g_act(cell_gate)
+        cell_state = fgt_gate * cell_state + inp_gate * cell_gate
+        if p_o is not None:
+            otp_gate = f_act(otp_gate + p_o * cell_state)
+        else:
+            otp_gate = f_act(otp_gate)
+
+        hidden_state = otp_gate * h_act(cell_state)
+
+        if proj is not None:
+            hidden_state = _op.nn.dense(hidden_state, proj)
+
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+
+    return outputs_list, hidden_state, cell_state
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 8f863aacab9a..adbbaf9ce885 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@
     infer_value,
     new_var,
     unbind,
+    lstm_cell,
 )

 __all__ = ["from_onnx"]
@@ -2153,7 +2154,7 @@ def bidir_lstm_cell(
         Bidirectional LSTM cell
         """
         seq_len = len(input_seqs)
-        forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell(
+        forward_outputs, fw_H_t, fw_C_t = lstm_cell(
             input_seqs,
             **weight_dicts[0],
             f_act=acts[0],
@@ -2161,7 +2162,7 @@ def bidir_lstm_cell(
             h_act=acts[2],
         )

-        reverse_outputs, rev_H_t, rev_C_t =
_op.lstm_cell( + reverse_outputs, rev_H_t, rev_C_t = lstm_cell( input_seqs, **weight_dicts[1], f_act=acts[3], @@ -2290,7 +2291,7 @@ def _impl_v7(cls, inputs, attr, params): ) else: # outputs shape = [seqs_num, (batch_size, hidden_size)] - outputs, H, C = _op.lstm_cell( + outputs, H, C = lstm_cell( input_seqs=X_steps, **weights_dicts[0], f_act=acts[0], diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3417d9c5247e..fbeacdddde05 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -39,7 +39,7 @@ from ..prelude import Prelude, StaticTensorArrayOps from ..ty import Any, TensorType, TupleType from . import qnn_torch -from .common import AttrCvt, get_relay_op, unbind +from .common import AttrCvt, get_relay_op, unbind, lstm_cell from .common import infer_value as _infer_value from .common import infer_shape as _infer_shape from .common import infer_value_simulated as _infer_value_simulated @@ -2321,12 +2321,12 @@ def bidir_lstm_cell( Bidirectional LSTM cell """ seq_len = len(input_seqs) - forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell( + forward_outputs, fw_H_t, fw_C_t = lstm_cell( input_seqs, **weights_dicts[0], ) - reverse_outputs, rev_H_t, rev_C_t = _op.lstm_cell( + reverse_outputs, rev_H_t, rev_C_t = lstm_cell( input_seqs, **weights_dicts[1], backwards=True, @@ -2355,7 +2355,7 @@ def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dro if bidirectional: input_seqs, H_t, C_t = self.bidir_lstm_cell(input_seqs, weights_dicts) else: - input_seqs, H_t, C_t = _op.lstm_cell(input_seqs, **weights_dicts[0]) + input_seqs, H_t, C_t = lstm_cell(input_seqs, **weights_dicts[0]) output_hiddens.append((H_t, C_t)) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 7412042b00ca..2e509a111c4a 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -38,7 +38,6 @@ from .tensor import * from .transform import * from .algorithm import * -from .layers import * from . import vm from . import nn from . import annotation diff --git a/python/tvm/relay/op/layers.py b/python/tvm/relay/op/layers.py deleted file mode 100644 index 9acf223f8240..000000000000 --- a/python/tvm/relay/op/layers.py +++ /dev/null @@ -1,115 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Intermediate representation of complicated layers unified for all frontends""" - -from .tensor import sigmoid, tanh, concatenate -from .transform import split -from . 
import nn - - -def lstm_cell( - input_seqs, - hidden_state, - cell_state, - w_inp, - w_hid, - b_inp=None, - b_hid=None, - proj=None, - p_i=None, - p_f=None, - p_o=None, - f_act=sigmoid, - g_act=tanh, - h_act=tanh, - backwards=False, -): - """ - Common implementation of LSTM cell for all frontends of TVM - TODO (vvchernov): currently it is used by onnx and pytorch. - - Parameters - ---------- - input_seqs : List[relay.Expr] - The sequence of input tensors - Input tensor should be 2d while issue #8412 is not resolved - Shape = (batch, feature_size) - hidden_state : relay.Expr - Hidden state. shape = (batch, hidden_size) - cell_state : relay.Expr - Cell state. shape = (batch, hidden_size) - w_inp, w_hid : relay.Expr - weight matrices. wi shape = (4 * hidden_size, feature_size) - wh shape = (4 * hidden_size, hidden_size or proj_size) - NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates. - The order is important for correct LSTM calculation! - b_inp, b_hid : relay.Expr - bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size) - proj : relay.Expr - projection matrix. shape = (proj_size, hidden_size) - p_i, p_f, p_o : relay.Expr - peephole LSTM matrices. shape = (batch, hidden_size) - f_act, g_act, h_act : relay.op - activation funtions - backwards : bool - Flag for reverse pass of LSTM - - Returns - ------- - result : List[relay.Expr], relay.Expr, relay.Expr - The sequence of computed result, final hidden and cell state - """ - - outputs_list = [] - for x_t in input_seqs if not backwards else reversed(input_seqs): - # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) - step = concatenate([x_t, hidden_state], axis=1) - cat_w = concatenate([w_inp, w_hid], axis=1) - # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid) - # the nn.dense(step, cat_w) is used - # gates shape = (batch, 4 * hidden_size) - gates = nn.dense(step, cat_w) - # Add biases - if b_inp is not None: - gates += b_inp - if b_hid is not None: - gates += b_hid - inp_gate, fgt_gate, cell_gate, otp_gate = split(gates, 4, axis=-1) # (batch, hidden_size) - - if p_i is not None and p_f is not None: - inp_gate = f_act(inp_gate + p_i * cell_state) - fgt_gate = f_act(fgt_gate + p_f * cell_state) - else: - inp_gate = f_act(inp_gate) - fgt_gate = f_act(fgt_gate) - - cell_gate = g_act(cell_gate) - cell_state = fgt_gate * cell_state + inp_gate * cell_gate - if p_o is not None: - otp_gate = f_act(otp_gate + p_o * cell_state) - else: - otp_gate = f_act(otp_gate) - - hidden_state = otp_gate * h_act(cell_state) - - if proj is not None: - hidden_state = nn.dense(hidden_state, proj) - - outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)] - - return outputs_list, hidden_state, cell_state From e103b20ca9294cb959ed3821fb27d530b7973c03 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 3 Aug 2021 17:08:56 +0300 Subject: [PATCH 17/21] clean up weight dictionary initialization --- python/tvm/relay/frontend/pytorch.py | 50 ++++++++-------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fbeacdddde05..4e06f15cedda 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2490,27 +2490,18 @@ def lstm(self, inputs, input_types): layer_weights_dicts = [] k = 0 # layer counter if has_biases: + names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"] if bidirectional: rsd = 
len(_weights) % (2 * weights_num) assert rsd == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), 2 * weights_num): - fw_weights_dict = {} - fw_weights_dict["hidden_state"] = layers_h[2 * k] - fw_weights_dict["cell_state"] = layers_c[2 * k] - fw_weights_dict["w_inp"] = _weights[i] - fw_weights_dict["w_hid"] = _weights[i + 1] - fw_weights_dict["b_inp"] = _weights[i + 2] - fw_weights_dict["b_hid"] = _weights[i + 3] + fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) if has_proj: fw_weights_dict["proj"] = _weights[i + 4] - rev_weights_dict = {} j = i + weights_num - rev_weights_dict["hidden_state"] = layers_h[2 * k + 1] - rev_weights_dict["cell_state"] = layers_c[2 * k + 1] - rev_weights_dict["w_inp"] = _weights[j] - rev_weights_dict["w_hid"] = _weights[j + 1] - rev_weights_dict["b_inp"] = _weights[j + 2] - rev_weights_dict["b_hid"] = _weights[j + 3] + rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]] + rev_weights_dict = dict(zip(names, rev_tensors)) if has_proj: rev_weights_dict["proj"] = _weights[j + 4] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) @@ -2518,35 +2509,25 @@ def lstm(self, inputs, input_types): else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): - fw_weights_dict = {} - fw_weights_dict["hidden_state"] = layers_h[k] - fw_weights_dict["cell_state"] = layers_c[k] - fw_weights_dict["w_inp"] = _weights[i] - fw_weights_dict["w_hid"] = _weights[i + 1] - fw_weights_dict["b_inp"] = _weights[i + 2] - fw_weights_dict["b_hid"] = _weights[i + 3] + fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) if has_proj: fw_weights_dict["proj"] = _weights[i + 4] layer_weights_dicts.append([fw_weights_dict]) k += 1 else: + names = ["hidden_state", "cell_state", "w_inp", "w_hid"] if bidirectional: rsd = len(_weights) % (2 * weights_num) assert rsd == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), 2 * weights_num): - fw_weights_dict = {} - fw_weights_dict["hidden_state"] = layers_h[2 * k] - fw_weights_dict["cell_state"] = layers_c[2 * k] - fw_weights_dict["w_inp"] = _weights[i] - fw_weights_dict["w_hid"] = _weights[i + 1] + fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) if has_proj: fw_weights_dict["proj"] = _weights[i + 2] - rev_weights_dict = {} j = i + weights_num - rev_weights_dict["hidden_state"] = layers_h[2 * k + 1] - rev_weights_dict["cell_state"] = layers_c[2 * k + 1] - rev_weights_dict["w_inp"] = _weights[j] - rev_weights_dict["w_hid"] = _weights[j + 1] + rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]] + rev_weights_dict = dict(zip(names, rev_tensors)) if has_proj: rev_weights_dict["proj"] = _weights[j + 2] layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) @@ -2554,11 +2535,8 @@ def lstm(self, inputs, input_types): else: assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" for i in range(0, len(_weights), weights_num): - fw_weights_dict = {} - fw_weights_dict["hidden_state"] = layers_h[k] - fw_weights_dict["cell_state"] = layers_c[k] - fw_weights_dict["w_inp"] = _weights[i] - fw_weights_dict["w_hid"] = _weights[i + 1] + fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) if 
has_proj: fw_weights_dict["proj"] = _weights[i + 2] layer_weights_dicts.append([fw_weights_dict]) From d3df533fc253586386420581c2a5372737a09447 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 3 Aug 2021 21:51:02 +0300 Subject: [PATCH 18/21] fix pytorch frontend wrapper over unbind method --- python/tvm/relay/frontend/common.py | 3 +-- python/tvm/relay/frontend/pytorch.py | 7 ++++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 85005b8302b4..f67a05825212 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -654,8 +654,7 @@ def unbind(data, axis=0): ret = [] for i in range(selections): ret.append(_op.squeeze(res_split[i], axis=[axis])) - ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) - return ret + return _expr.TupleWrapper(_expr.Tuple(ret), selections) def lstm_cell( diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 4e06f15cedda..321bdc6d62e5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2092,6 +2092,11 @@ def deform_conv2d(self, inputs, input_types): kernel_size, ) + def unbind(self, inputs, input_types): + data = inputs[0] + axis = int(inputs[1]) + return unbind(data, axis) + def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False input_shape = self.infer_shape(inputs[0], self.prelude.mod) @@ -2752,7 +2757,7 @@ def create_convert_map(self): "aten::logsumexp": self.logsumexp, "torchvision::roi_align": self.roi_align, "torchvision::deform_conv2d": self.deform_conv2d, - "aten::unbind": unbind, + "aten::unbind": self.unbind, "aten::__and__": self.logical_and, "aten::logical_and": self.logical_and, "aten::_shape_as_tensor": self.shape_as_tensor, From 4ce5982017cd376a133cbf4f258f51c4b6f361db Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 4 Aug 2021 09:27:46 +0300 Subject: [PATCH 19/21] minor fix of comments --- python/tvm/relay/frontend/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index f67a05825212..7f67ed404de9 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -628,8 +628,8 @@ def to_int_list(np_array): def unbind(data, axis=0): """ - Unbind was gotten from pytorch.py and modified. The operation removes a tensor dimension - and returns a tuple of all slices along a given dimension, already without it. + Unbind was taken from Pytorch frontend. The operation removes a tensor dimension + and returns a tuple of all slices along a given dimension, with specified axis removed. TODO (vvchernov): It needs such operation on relay side to reduce time consumption on squeeze operation. @@ -638,7 +638,7 @@ def unbind(data, axis=0): data : relay.Expr Input tensor axis : int - Axis along which tensor is splited. Tensors in the list do not have this axis. + Axis along which tensor is split. 
Returns ------- result : List[relay.Expr] @@ -716,7 +716,7 @@ def lstm_cell( step = _op.concatenate([x_t, hidden_state], axis=1) cat_w = _op.concatenate([w_inp, w_hid], axis=1) # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid) - # the nn.dense(step, cat_w) is used + # nn.dense(step, cat_w) is used # gates shape = (batch, 4 * hidden_size) gates = _op.nn.dense(step, cat_w) # Add biases From 47e1bb22df8bcd8f64e396cbaba5e83425295209 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 4 Aug 2021 16:18:51 +0300 Subject: [PATCH 20/21] empty. tvm-ci test restart From 3ed483625a70cc96997da3916a0c58da828cffd5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 4 Aug 2021 21:32:36 +0300 Subject: [PATCH 21/21] empty. tvm-ci test restart
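
A usage note on the end state of this series: below is a minimal sketch of how the unified helpers the patches land in python/tvm/relay/frontend/common.py (unbind and lstm_cell) could be driven directly, outside the ONNX and PyTorch converters. The tensor sizes, variable names, and the relay.Function wrapper are illustrative assumptions, not values taken from the patches.

# A hypothetical single-direction forward pass over a 5-step sequence;
# all shapes are assumptions chosen to satisfy the lstm_cell docstring.
from tvm import relay
from tvm.relay.frontend.common import lstm_cell, unbind

seq_len, batch, feature_size, hidden_size = 5, 2, 8, 16

X = relay.var("X", shape=(seq_len, batch, feature_size), dtype="float32")
h0 = relay.var("h0", shape=(batch, hidden_size), dtype="float32")
c0 = relay.var("c0", shape=(batch, hidden_size), dtype="float32")
# Both weight matrices must keep the i-f-c-o gate order the docstring requires;
# the ONNX converter above permutes from i-o-f-c into this layout for the same reason.
w_inp = relay.var("w_inp", shape=(4 * hidden_size, feature_size), dtype="float32")
w_hid = relay.var("w_hid", shape=(4 * hidden_size, hidden_size), dtype="float32")

# unbind turns (seq_len, batch, feature_size) into seq_len tensors of shape
# (batch, feature_size), the [seq_num, (batch, feature_size)] layout lstm_cell iterates over.
input_seqs = unbind(X, axis=0)

outputs, last_h, last_c = lstm_cell(input_seqs, h0, c0, w_inp, w_hid)

# Stack the per-step hidden states back into (seq_len, batch, hidden_size).
stacked = relay.op.stack(outputs, axis=0)
func = relay.Function(relay.analysis.free_vars(stacked), stacked)

This is the same call the converters build their weights_dict around; the bidirectional and stacked variants simply wrap this cell, as the bidir_lstm_cell and lstm_layers hunks above show.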