From 7e971d2cc583d9fb0f835de5b63b3dd228a6a89a Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Wed, 27 Sep 2023 10:21:30 +0800 Subject: [PATCH 1/3] fix bug in gru and simpleRNN about go_backwards --- python/tvm/relay/frontend/keras.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 9e09cb400ab2..f6b9c4745a30 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1060,6 +1060,8 @@ def _convert_simple_rnn( assert units > 0, "The value of units must be a positive integer" if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) assert len(in_data.type_annotation.shape) == 3 timeDim = in_data.type_annotation.shape[1].value in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) @@ -1090,6 +1092,8 @@ def _convert_gru( recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" in_data = _op.nn.batch_flatten(in_data) From 7281e04b0a586d28d2b494f4d029d4c343523f11 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Wed, 27 Sep 2023 10:24:11 +0800 Subject: [PATCH 2/3] Update test_forward.py --- tests/python/frontend/keras/test_forward.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 9d33b15a9179..43fe6a5fdca3 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod): keras_mod.layers.SimpleRNN( units=16, return_state=False, activation="tanh", use_bias=False ), + keras_mod.layers.SimpleRNN( + units=16, return_state=False, activation="tanh", go_backwards=True + ), + keras_mod.layers.GRU( + units=16, + return_state=False, + recurrent_activation="sigmoid", + activation="tanh", + reset_after=False, + ), keras_mod.layers.GRU( units=16, return_state=False, recurrent_activation="sigmoid", activation="tanh", reset_after=False, + use_bias=False, ), keras_mod.layers.GRU( units=16, @@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod): activation="tanh", reset_after=False, use_bias=False, + go_backwards=True, ), ] for rnn_func in rnn_funcs: From 2ad6cea8404d4fdf8873ef6b2249740b7cdd22dd Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Wed, 27 Sep 2023 12:55:36 +0800 Subject: [PATCH 3/3] Update keras.py --- python/tvm/relay/frontend/keras.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index f6b9c4745a30..6c82ebb427e7 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1060,10 +1060,10 @@ def _convert_simple_rnn( assert units > 0, "The value of units must be a positive integer" if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) - if keras_layer.go_backwards: - in_data = _op.reverse(in_data, axis=1) assert len(in_data.type_annotation.shape) == 3 timeDim = in_data.type_annotation.shape[1].value + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) for i in range(len(in_data_split)): in_data_split_i = _op.nn.batch_flatten(in_data_split[i])