From 429015f9ad39fa9d2cf413afd53a86d8c2386eda Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 15 May 2018 10:55:03 -0700 Subject: [PATCH] fix rnn --- python/mxnet/gluon/rnn/rnn_layer.py | 5 +++-- tests/python/unittest/test_gluon_rnn.py | 29 +++++++++++++++++++------ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 34ad05d5cc90..89224cf6f9b8 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -23,7 +23,7 @@ from __future__ import print_function __all__ = ['RNN', 'LSTM', 'GRU'] -from ... import ndarray +from ... import ndarray, autograd from .. import Block from . import rnn_cell @@ -185,7 +185,8 @@ def forward(self, inputs, states=None): for i in range(self._dir): self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) self.i2h_weight[i]._finish_deferred_init() - if inputs.context.device_type == 'gpu' or self._mode == 'lstm': + if inputs.context.device_type == 'gpu' or \ + self._mode == 'lstm' and not (self._dropout and autograd.is_training()): out = self._forward_kernel(inputs, states) else: out = self._forward(inputs, states) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index f22b13d65752..24d5a932d7b2 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -80,7 +80,7 @@ def test_lstm_cpu_inference(): mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, rtol=1e-3, atol=1e-5) - + def test_gru(): cell = gluon.rnn.GRUCell(100, prefix='rnn_') @@ -242,7 +242,7 @@ def test_rnn_cells(): net.add(gluon.rnn.GRUCell(100, input_size=100)) check_rnn_forward(net, mx.nd.ones((8, 3, 200))) -def check_rnn_layer_forward(layer, inputs, states=None): +def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): layer.collect_params().initialize() inputs.attach_grad() with mx.autograd.record(): @@ -268,17 +268,32 @@ def check_rnn_layer_forward(layer, inputs, states=None): assert isinstance(out, mx.nd.NDArray) out.backward() - mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5) - mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5) + if not run_only: + mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5) + mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5) def test_rnn_layers(): check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10))) + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10))) check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))]) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))]) check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10))) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10))) + + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), + run_only=True) + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5), + mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), + run_only=True) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5), + mx.nd.ones((8, 3, 20)), + [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), + run_only=True) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5), + mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) net = gluon.nn.Sequential() net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))