diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index efca0c3d2526..76415090ccff 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -62,7 +62,7 @@ def __repr__(self): modstr=modstr) def __getitem__(self, key): - return self._children[str(key)] + return list(self._children.values())[key] def __len__(self): return len(self._children) @@ -119,7 +119,7 @@ def __repr__(self): modstr=modstr) def __getitem__(self, key): - return self._children[str(key)] + return list(self._children.values())[key] def __len__(self): return len(self._children) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index ca1e121008d8..854e6fe07f18 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -609,6 +609,23 @@ def __init__(self, **kwargs): model.collect_params() assert len(w) == 0 +def check_sequential(net): + dense1 = gluon.nn.Dense(10) + net.add(dense1) + dense2 = gluon.nn.Dense(10) + net.add(dense2) + dense3 = gluon.nn.Dense(10) + net.add(dense3) + + assert net[1] is dense2 + assert net[-1] is dense3 + slc = net[1:3] + assert len(slc) == 2 and slc[0] is dense2 and slc[1] is dense3 + +@with_seed() +def test_sequential(): + check_sequential(gluon.nn.Sequential()) + check_sequential(gluon.nn.HybridSequential()) @with_seed() def test_sequential_warning():