diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4c874672445b..33cb83b883bc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -41,6 +41,7 @@ from . import qnn_torch
 from .common import AttrCvt, get_relay_op
 from .common import infer_value as _infer_value
+from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import try_infer_value
 from .pytorch_utils import is_version_greater_than
@@ -2329,6 +2330,298 @@ def flip(self, inputs, input_types):
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
+        if has_proj:
+            assert len(weights) == 5
+        else:
+            assert len(weights) == 4
+        outputs_list = []
+        # Default activation types
+        f_act = _op.sigmoid
+        g_act = _op.tanh
+        h_act = _op.tanh
+
+        # Input hiddens
+        H_t = hidden[0]  # (batch, hidden_size)
+        C_t = hidden[1]  # (batch, hidden_size)
+        for x_t in input_seqs:
+            # x_t shape = (batch, feature_size)
+            # gates shape = (batch, 4 * hidden_size)
+            gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1])
+            # Add biases
+            if weights[2] is not None:
+                gates += weights[2]
+            if weights[3] is not None:
+                gates += weights[3]
+            i, f, c, o = _op.split(gates, 4, axis=-1)  # (batch, hidden_size)
+
+            i = f_act(i)
+            f = f_act(f)
+            c = g_act(c)
+            o = f_act(o)
+
+            C = f * C_t + i * c
+            H = o * h_act(C)
+
+            if has_proj:
+                H = _op.nn.dense(H, weights[4])
+
+            H_t = H
+            C_t = C
+            outputs_list.append(H)  # [seq_num, (batch, hidden_size)]
+        hidden_outputs = (H_t, C_t)
+
+        return (outputs_list, hidden_outputs)
+
+    def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False):
+        fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)
+
+        rev_input_seq = []
+        seq_len = len(input_seq)
+        for i in range(seq_len):
+            rev_input_seq.append(input_seq[seq_len - 1 - i])  # [seq_num, (batch, feature_size)]
+        rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj)
+
+        final_outputs = []  # [seq_num, (batch, 2 * hidden_size)]
+        for j in range(seq_len):
+            final_outputs.append(
+                _op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1)
+            )
+
+        return final_outputs, (fw_outputs[1], rev_outputs[1])
+
+    def lstm_layers(
+        self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
+    ):
+        hidden_layers_num = len(hiddens)
+        assert len(weights) == hidden_layers_num
+
+        # split the input sequence into individual samples
+        input_seqs = self.unbind((input_data, 0), dtype)  # [seq_num, (batch, feature_size)]
+        output_hiddens = []
+        for k in range(hidden_layers_num):
+            hiddens_input = hiddens[k]
+            weights_input = weights[k]
+
+            outputs = (
+                self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
+                if bidirectional
+                else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
+            )
+
+            output_hiddens.append(outputs[1])
+            # input_seqs shape = [seq_num, (batch, hidden_size)] or
+            # [seq_num, (batch, 2 * hidden_size)] for bidirectional
+            input_seqs = outputs[0]
+
+            # TODO (vvchernov): in the pytorch implementation the train flag is also checked,
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
+            if dropout_p != 0 and k < hidden_layers_num - 1:
+                # for input in input_seqs:
+                #     input = _op.dropout(input, dropout_p)
+                raise NotImplementedError("Dropout for LSTM is not supported yet!")
+        final_hiddens = []
+        if bidirectional:
+            for i in range(hidden_layers_num):
+                final_hiddens.append(output_hiddens[i][0])
+                final_hiddens.append(output_hiddens[i][1])
+        else:
+            final_hiddens = output_hiddens
+
+        return _op.stack(input_seqs, 0), final_hiddens
+
+    def lstm(self, inputs, input_types):
+        """
+        Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
+        Native implementation for torch version less than 1.8.0 (projection is unsupported):
+        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
+        src/ATen/native/RNN.cpp#L1396
+        Native implementation for torch version 1.8.0 and higher (projection is supported):
+        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
+        """
+        # TODO (vvchernov): support dropout
+        assert len(inputs) == 9, "Input of size 9 is expected"
+        # Unpack inputs; optional inputs that are not provided will be None.
+        _X = inputs[0]
+        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+        hidden_states = inputs[1]
+        assert len(hidden_states) == 2, "lstm expects two hidden states"
+        h_0 = hidden_states[0]
+        c_0 = hidden_states[1]
+        # H0 shape (hidden_layers_num, batch, proj_size) if projection
+        # else (hidden_layers_num, batch, hidden_size)
+        # C0 shape (hidden_layers_num, batch, hidden_size)
+
+        _weights = inputs[2]
+        # If no projection
+        # Wi layer[0] shape (4 * hidden_size, feature_size)
+        # Wh layer[0] shape (4 * hidden_size, hidden_size)
+        # Bi layer[0] shape (4 * hidden_size)
+        # Bh layer[0] shape (4 * hidden_size)
+
+        # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
+        # Wh layer[>0] shape (4 * hidden_size, hidden_size)
+        # Bi layer[>0] shape (4 * hidden_size)
+        # Bh layer[>0] shape (4 * hidden_size)
+
+        # If projection
+        # Wi layer[0] shape (4 * hidden_size, feature_size)
+        # Wh layer[0] shape (4 * hidden_size, proj_size)
+        # Bi layer[0] shape (4 * hidden_size)
+        # Bh layer[0] shape (4 * hidden_size)
+        # P layer[0] shape (proj_size, hidden_size)
+
+        # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
+        # Wh layer[>0] shape (4 * hidden_size, proj_size)
+        # Bi layer[>0] shape (4 * hidden_size)
+        # Bh layer[>0] shape (4 * hidden_size)
+        # P layer[>0] shape (proj_size, hidden_size)
+
+        # Scalar inputs
+        has_biases = inputs[3]
+        num_layers = inputs[4]
+        dropout_p = inputs[5]  # dropout probability; 0.0 means no dropout
+        # train = inputs[6]
+        bidirectional = inputs[7]
+        batch_first = inputs[8]
+
+        num_directions = 1
+        if bidirectional:
+            num_directions = 2
+
+        rsd = len(_weights) % num_layers
+        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
+        rsd = (len(_weights) // num_layers) % num_directions
+        assert (
+            rsd == 0
+        ), "The number of weights per layer must be a multiple of the number of directions!"
+        has_proj = False
+        proj_size = 0
+        weights_num = int(len(_weights) / num_layers / num_directions)
+        if has_biases:
+            if weights_num == 5:
+                has_proj = True
+                proj_size = _infer_shape(_weights[4])[0]
+            else:
+                assert weights_num == 4, "The number of weights per layer is expected to be 4"
+        else:
+            if weights_num == 3:
+                has_proj = True
+                proj_size = _infer_shape(_weights[2])[0]
+            else:
+                assert weights_num == 2, "The number of weights per layer is expected to be 2"
+
+        weights = []
+        if has_biases:
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_weights = []
+                    rev_weights = []
+                    for j in range(weights_num):
+                        fw_weights.append(_weights[i + j])
+                        rev_weights.append(_weights[i + j + weights_num])
+                    weights.append((fw_weights, rev_weights))
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_weights = []
+                    for j in range(weights_num):
+                        fw_weights.append(_weights[i + j])
+                    weights.append(fw_weights)
+        else:
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_weights = []
+                    rev_weights = []
+                    k = i + weights_num
+                    if has_proj:
+                        fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
+                        rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]]
+                    else:
+                        fw_weights = [_weights[i], _weights[i + 1], None, None]
+                        rev_weights = [_weights[k], _weights[k + 1], None, None]
+                    weights.append((fw_weights, rev_weights))
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                for i in range(0, len(_weights), weights_num):
+                    if has_proj:
+                        fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
+                    else:
+                        fw_weights = [_weights[i], _weights[i + 1], None, None]
+                    weights.append(fw_weights)
+        assert (
+            len(weights) == num_layers
+        ), "For a stacked LSTM, the number of weight tuples must equal the number of layers!"
+
+        X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+        # TODO (vvchernov): Which data type should be used? from input or weights?
+        # Alternatively, _infer_type(X).checked_type.dtype can be used
+        X_dtype = input_types[0]
+        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+
+        # Wi of the first layer has shape (4 * hidden_size, feature_size)
+        hidden_size = int(_infer_shape(_weights[0])[0] / 4)
+        batch_size = X_shape[1]
+
+        # Initialize hidden states if not provided.
+        layers_h = []
+        layers_c = []
+        hidden_layers_num = num_directions * num_layers
+        if h_0 is None:
+            if has_proj:
+                h_0 = _op.zeros((batch_size, proj_size), X_dtype)
+            else:
+                h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_h.append(h_0)
+        else:
+            layers_h = self.unbind((h_0, 0), X_dtype)
+        if c_0 is None:
+            c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_c.append(c_0)
+        else:
+            layers_c = self.unbind((c_0, 0), X_dtype)
+
+        hiddens = []
+        for i in range(num_layers):
+            if bidirectional:
+                hiddens.append(
+                    ((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1]))
+                )
+            else:
+                hiddens.append((layers_h[i], layers_c[i]))
+
+        outputs = self.lstm_layers(
+            X,
+            hiddens,
+            weights,
+            bidirectional,
+            dtype=X_dtype,
+            dropout_p=dropout_p,
+            has_proj=has_proj,
+        )
+
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2 * hidden_size) for bidirectional
+        output = outputs[0]
+
+        hy = []
+        cy = []
+        for hidden in outputs[1]:
+            hy.append(hidden[0])
+            cy.append(hidden[1])
+
+        if batch_first:
+            output = _op.transpose(output, (1, 0, 2))
+
+        return (output, _op.stack(hy, 0), _op.stack(cy, 0))
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2545,6 +2838,7 @@ def create_convert_map(self):
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
             "aten::flip": self.flip,
+            "aten::lstm": self.lstm,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py
new file mode 100644
index 000000000000..e780ae725b74
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_lstms.py
@@ -0,0 +1,363 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
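For reference, the converter's lstm_cell above unrolls the standard LSTM recurrence with PyTorch's (i, f, g, o) gate ordering, expressed through relay dense/split/elementwise ops. Below is a minimal NumPy sketch of a single time step for comparison; the function and argument names are illustrative and not part of the patch.

import numpy as np

def lstm_cell_reference(x_t, h_t, c_t, w_ih, w_hh, b_ih=None, b_hh=None, w_proj=None):
    # One LSTM time step; the gate order (i, f, g, o) matches torch.nn.LSTM.
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    # Gates have shape (batch, 4 * hidden_size), like the converter's dense calls.
    gates = x_t @ w_ih.T + h_t @ w_hh.T
    if b_ih is not None:
        gates = gates + b_ih
    if b_hh is not None:
        gates = gates + b_hh
    i, f, g, o = np.split(gates, 4, axis=-1)
    c_next = sigmoid(f) * c_t + sigmoid(i) * np.tanh(g)
    h_next = sigmoid(o) * np.tanh(c_next)
    if w_proj is not None:  # optional projection, applied when proj_size > 0
        h_next = h_next @ w_proj.T
    return h_next, c_next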
+
+import tvm
+import tvm.testing
+import numpy as np
+import torch
+import onnx
+import io
+import pytest
+
+from tvm import relay
+from tvm.contrib import graph_executor
+
+from torch import nn
+
+## Model parameters
+model_feature_size = 5
+model_hidden_size = 10
+model_num_layers = 2
+seqs_length = 15
+projection_size = 7
+batch_size = 3
+
+
+def check_torch_version_for_proj_in_lstm():
+    """
+    The proj_size parameter of torch.nn.LSTM is supported starting from torch 1.8.0
+    """
+    is_supported = False
+
+    version = torch.__version__
+    major, minor = version.split(".")[:2]
+
+    if int(major) > 1 or (int(major) == 1 and int(minor) >= 8):
+        is_supported = True
+
+    return is_supported
+
+
+class LSTM_Model(nn.Module):
+    def __init__(
+        self,
+        device,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        proj_size=0,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_first = batch_first
+        self.use_bias = use_bias
+
+        if check_torch_version_for_proj_in_lstm():
+            self.lstm = nn.LSTM(
+                input_size=model_feature_size,
+                hidden_size=model_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                proj_size=proj_size,
+                batch_first=batch_first,
+                bias=use_bias,
+            ).to(device)
+        else:
+            if proj_size > 0:
+                print(
+                    "WARNING: projection is not supported for torch versions below 1.8.0! ",
+                    "LSTM was constructed without projection!",
+                )
+            self.lstm = nn.LSTM(
+                input_size=model_feature_size,
+                hidden_size=model_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                batch_first=batch_first,
+                bias=use_bias,
+            ).to(device)
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def forward(self, input, hidden_init=None):
+        """
+        Computes the output tensor by running the input through the LSTM layers.
+
+        :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True
+        :param hidden_init: initial hidden and cell states of the LSTM as a tuple of tensors of shape (num_layers * num_directions, batch_size, hidden_size). Will default to tensors of zeros if None.
+        :return: the output tensor of shape (seqs_length, batch_size, num_directions * model_hidden_size) or (batch_size, seqs_length, num_directions * model_hidden_size) if self.batch_first = True
+        """
+        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+        # and the final cell state.
+        out, (hidden, cell) = self.lstm(input, hidden_init)
+
+        return out
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        Without projection:
+            For first weights group:
+                Wi (4*model_hidden_size, model_feature_size)
+                Wh (4*model_hidden_size, model_hidden_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*model_hidden_size, model_feature_size)
+                Wh (4*model_hidden_size, model_hidden_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+            For other weights group:
+                Wi (4*model_hidden_size, model_hidden_size)
+                Wh (4*model_hidden_size, model_hidden_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+        With projection:
+            For first weights group:
+                Wi (4*model_hidden_size, model_feature_size)
+                Wh (4*model_hidden_size, proj_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+                P (proj_size, model_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*model_hidden_size, model_feature_size)
+                Wh (4*model_hidden_size, proj_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+                P (proj_size, model_hidden_size)
+            For other weights group:
+                Wi (4*model_hidden_size, proj_size * num_directions)
+                Wh (4*model_hidden_size, proj_size)
+                Bi (4*model_hidden_size)
+                Bh (4*model_hidden_size)
+                P (proj_size, model_hidden_size)
+        When random weights are generated for a model without biases, Bi and Bh are skipped
+        """
+        for weight_group in self.lstm.all_weights:
+            for weight in weight_group:
+                weight.data = torch.rand(weight.shape)
+
+    def get_dummy_input(self):
+        shape = [seqs_length, batch_size, model_feature_size]
+        if self.batch_first:
+            shape = [batch_size, seqs_length, model_feature_size]
+        res = torch.rand(shape)
+
+        return res, shape
+
+
+def compare(input, gold_data, rtol=1e-5, atol=1e-5):
+    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
+
+
+def check_lstm_with_type(
+    lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)
+):
+    has_proj = "p" in lstm_type
+
+    device = torch.device("cpu")
+    hidden_layers_num = 1
+    model = None
+    for batch_first in (True, False):
+        for use_bias in (True, False):
+            for rnd_weights in (True, False):
+                if lstm_type == "uni":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                elif lstm_type == "b":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2
+                elif lstm_type == "p":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                elif lstm_type == "s":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        layer_num=model_num_layers,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = model_num_layers
+                elif lstm_type == "sb":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        layer_num=model_num_layers,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2 * model_num_layers
+                elif lstm_type == "sp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        layer_num=model_num_layers,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = model_num_layers
+                elif lstm_type == "bp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2
+                elif lstm_type == "sbp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        layer_num=model_num_layers,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2 * model_num_layers
+                else:
+                    print("WARNING: LSTM type {} is not supported here!".format(lstm_type))
+                    return
+
+                model.eval()
+
+                # Get golden output from original model
+                input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
+                input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
+                dummy_input, input_shape = model.get_dummy_input()
+                golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
+
+                dtype = "float32"
+                h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+                if has_proj:
+                    h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
+                c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+
+                tvm_output = None
+                for fmt in ("ts", "onnx"):
+                    if fmt == "ts":
+                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+                        traced_script_module = torch.jit.trace(model, dummy_input).eval()
+
+                        # Import model to Relay
+                        shape_list = [("input", input_shape)]
+                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
+
+                        # Model compilation by tvm
+                        with tvm.transform.PassContext(opt_level=3):
+                            lib = relay.build(mod, target=target, params=params)
+                    elif fmt == "onnx":
+                        if has_proj:
+                            print(
+                                "WARNING: torch.onnx.export does not support conversion of LSTM "
+                                "with projection from pytorch! TODO: test this path once the "
+                                "support is added."
+                            )
+                            continue
+                        onnx_io = io.BytesIO()
+                        with torch.no_grad():
+                            h0 = torch.rand(input_hidden_shape)
+                            if has_proj:
+                                h0 = torch.rand(input_hidden_shape_with_proj)
+                            c0 = torch.rand(input_hidden_shape)
+                            input_names = ["input", "h0", "c0"]
+
+                            # default export (without dynamic input)
+                            torch.onnx.export(
+                                model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names
+                            )
+                        onnx_io.seek(0, 0)
+                        onnx_model = onnx.load_model(onnx_io)
+
+                        # Import model to Relay
+                        shape_dict = {
+                            "input": input_shape,
+                            "h0": input_hidden_shape,
+                            "c0": input_hidden_shape,
+                        }
+                        if has_proj:
+                            shape_dict = {
+                                "input": input_shape,
+                                "h0": input_hidden_shape_with_proj,
+                                "c0": input_hidden_shape,
+                            }
+                        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+                        # Model compilation by tvm
+                        with tvm.transform.PassContext(opt_level=1):
+                            lib = relay.build(mod, target=target, params=params)
+
+                    # Inference of the model with given input data
+                    m = graph_executor.GraphModule(lib["default"](dev))
+
+                    # Set inputs
+                    m.set_input(
+                        input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
+                        h0=tvm.nd.array(h_zeros),
+                        c0=tvm.nd.array(c_zeros),
+                    )
+                    # Execute
+                    m.run()
+                    # Get outputs (converted to numpy array)
+                    tvm_output = m.get_output(0).numpy()
+
+                    compare(tvm_output, golden_output_batch)
+
+
+@tvm.testing.uses_gpu
+def test_lstms():
+    for target, dev in tvm.testing.enabled_targets():
+        check_lstm_with_type("uni", target, dev)
+        check_lstm_with_type("p", target, dev)
+        check_lstm_with_type("s", target, dev)
+        check_lstm_with_type("b", target, dev)
+        check_lstm_with_type("bp", target, dev)
+        check_lstm_with_type("sp", target, dev)
+        check_lstm_with_type("sb", target, dev)
+        check_lstm_with_type("sbp", target, dev)
+
+
+if __name__ == "__main__":
+    test_lstms()
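With the aten::lstm mapping in place, a traced torch.nn.LSTM converts through the PyTorch frontend like any other module. A minimal usage sketch reusing the shape constants from the test above (an untested outline, not part of the patch):

import torch
import tvm
from tvm import relay

model = torch.nn.LSTM(input_size=5, hidden_size=10, num_layers=2).eval()
dummy = torch.rand(15, 3, 5)  # (seqs_length, batch_size, model_feature_size)
traced = torch.jit.trace(model, dummy)

mod, params = relay.frontend.from_pytorch(traced, [("input", (15, 3, 5))])
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)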