From c1fd7889a6700b005989e8d5bef45c626a45ca64 Mon Sep 17 00:00:00 2001 From: formath Date: Mon, 19 Jun 2017 19:01:50 +0800 Subject: [PATCH 1/2] set __layout__ attr for FucedRNNCell output states --- python/mxnet/rnn/rnn_cell.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index c00f8a39d8c3..31015faf6fcf 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -675,8 +675,11 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N if not self._get_next_state: outputs, states = rnn, [] elif self._mode == 'lstm': + rnn[1]._set_attr(__layout__='LNC') + rnn[2]._set_attr(__layout__='LNC') outputs, states = rnn[0], [rnn[1], rnn[2]] else: + rnn[1]._set_attr(__layout__='LNC') outputs, states = rnn[0], [rnn[1]] if axis == 1: From c11a64ad66c4b9c7f023b1ea7164e1b7039526de Mon Sep 17 00:00:00 2001 From: formath Date: Mon, 19 Jun 2017 19:29:39 +0800 Subject: [PATCH 2/2] right usage not that in desc --- python/mxnet/rnn/rnn_cell.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index 31015faf6fcf..d0505f87ac40 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -672,14 +672,15 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N mode=self._mode, name=self._prefix+'rnn', **states) + attr = {'__layout__' : 'LNC'} if not self._get_next_state: outputs, states = rnn, [] elif self._mode == 'lstm': - rnn[1]._set_attr(__layout__='LNC') - rnn[2]._set_attr(__layout__='LNC') + rnn[1]._set_attr(**attr) + rnn[2]._set_attr(**attr) outputs, states = rnn[0], [rnn[1], rnn[2]] else: - rnn[1]._set_attr(__layout__='LNC') + rnn[1]._set_attr(**attr) outputs, states = rnn[0], [rnn[1]] if axis == 1: