diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index aa185923d02e..bf6293a2a90c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -896,6 +896,7 @@ def _convert_lstm(inexpr, keras_layer, etab): in_data = _op.squeeze(in_data, axis=[0]) in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0) # loop for the number of time_steps + out_list = [] # store h outputs in case return_sequences is True for data in in_data: ixh1 = _op.nn.dense(data, kernel_weight, units=units) ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias) @@ -906,8 +907,11 @@ def _convert_lstm(inexpr, keras_layer, etab): next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None) out_gate = _convert_recurrent_activation(gates[3], keras_layer) next_h = out_gate * _convert_activation(next_c, keras_layer, None) + if keras_layer.return_sequences: + out_list.append(_op.expand_dims(next_h, axis=1)) + out = _op.concatenate(out_list, axis=1) if keras_layer.return_sequences else next_h out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) - out = _op.reshape(next_h, newshape=out_shape) + out = _op.reshape(out, newshape=out_shape) return [out, next_h, next_c] diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 26bf58cbf384..4dfe89fe40e5 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -417,6 +417,17 @@ def test_forward_reuse_layers(self, keras): keras_model = keras.models.Model(data, z) verify_keras_frontend(keras_model) + def test_forward_lstm(self, keras): + data = keras.layers.Input(shape=(10, 32)) + rnn_funcs = [ + keras.layers.LSTM(16), + keras.layers.LSTM(16, return_sequences=True), + ] + for rnn_func in rnn_funcs: + x = rnn_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + def test_forward_rnn(self, keras): data = keras.layers.Input(shape=(1, 32)) rnn_funcs = [ @@ -613,6 +624,7 @@ def test_forward_nested_layers(self, keras): sut.test_forward_multi_inputs(keras=k) sut.test_forward_multi_outputs(keras=k) sut.test_forward_reuse_layers(keras=k) + sut.test_forward_lstm(keras=k) sut.test_forward_rnn(keras=k) sut.test_forward_vgg16(keras=k) sut.test_forward_vgg16(keras=k, layout="NHWC")