diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py index 8a44a9cab174..08498724b1b4 100644 --- a/benchmark/python/control_flow/rnn.py +++ b/benchmark/python/control_flow/rnn.py @@ -32,6 +32,7 @@ _parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) _parser.add_argument('--warmup_rounds', type=int, default=20) _parser.add_argument('--test_rounds', type=int, default=100) +_parser.add_argument('--gpu', type=bool, default=False) args = _parser.parse_args() @@ -66,8 +67,7 @@ def _func(*states): loop_vars=states, max_iterations=self.length, ) - assert len(out) == 1 - return out[0] + return out def _zeros(shape, ctx): @@ -124,7 +124,9 @@ def main(): cell_types = [gluon.rnn.RNNCell, gluon.rnn.GRUCell, gluon.rnn.LSTMCell] - ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()] + ctxs = [mx.cpu(0)] + if args.gpu: + ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()] seq_lens = [100] batch_sizes = [1, 32] hidden_dims = [512]