From ca902ce66ad7f8007d01be74c9f5b93a370867c7 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 24 Sep 2018 21:12:26 +0000 Subject: [PATCH] fix benchmark on control flow operators. --- benchmark/python/control_flow/rnn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py index 8a44a9cab174..08498724b1b4 100644 --- a/benchmark/python/control_flow/rnn.py +++ b/benchmark/python/control_flow/rnn.py @@ -32,6 +32,7 @@ _parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) _parser.add_argument('--warmup_rounds', type=int, default=20) _parser.add_argument('--test_rounds', type=int, default=100) +_parser.add_argument('--gpu', type=bool, default=False) args = _parser.parse_args() @@ -66,8 +67,7 @@ def _func(*states): loop_vars=states, max_iterations=self.length, ) - assert len(out) == 1 - return out[0] + return out def _zeros(shape, ctx): @@ -124,7 +124,9 @@ def main(): cell_types = [gluon.rnn.RNNCell, gluon.rnn.GRUCell, gluon.rnn.LSTMCell] - ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()] + ctxs = [mx.cpu(0)] + if args.gpu: + ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()] seq_lens = [100] batch_sizes = [1, 32] hidden_dims = [512]