From f59325255baf97e115709f0cc7bcac8288962ed3 Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Thu, 26 Jul 2018 15:33:24 -0700 Subject: [PATCH 1/7] Add description about update on kvstore --- docs/faq/distributed_training.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/faq/distributed_training.md b/docs/faq/distributed_training.md index 70078ba60957..1f89a8dff121 100644 --- a/docs/faq/distributed_training.md +++ b/docs/faq/distributed_training.md @@ -73,6 +73,13 @@ These can be passed as arguments to the iterator. You can look at [example/gluon/image_classification.py](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py) to see an example usage. +### Updating weights +KVStore server supports two modes, one which aggregates the gradients and updates the weights using those gradients, and second where the server only aggregates gradients. In the latter case, when a worker process pulls from kvstore, it gets the aggregated gradients. The worker then uses these gradients and applies the weights locally. + +When using Gluon there is an option to choose between these modes by passing `update_on_kvstore` variable when you create the [Trainer](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/gluon.html#mxnet.gluon.Trainer) object. + +When using the symbolic interface, it performs the weight updates on the server without the user having to do anything special. + ### Different Modes of Distributed Training Distributed training itself is enabled when kvstore creation string contains the word `dist`. @@ -86,9 +93,9 @@ In this mode, if a worker crashes, then it halts the progress of all workers. - `dist_async`: In asynchronous distributed training, the server receives gradients from one worker and immediately updates its store, which it uses to respond to any future pulls. 
This means that a worker who finishes processing a batch can pull the current parameters from server and start the next batch, even if other workers haven't finished processing the earlier batch. -This is faster than `dist_sync` but can take more epochs to converge. -In `async` mode, it is required to pass an optimizer because in the absence of an optimizer kvstore would replace the stored weights with received weights and this doesn't make sense for training in asynchronous mode. +This is faster than `dist_sync` because there is no cost of synchronization, but can take more epochs to converge. The update of weights is atomic, meaning no two updates happen on the same weight at the same time. However, the order of updates is not guaranteed. +In `async` mode, it is required to pass an optimizer because in the absence of an optimizer kvstore would replace the stored weights with received weights and this doesn't make sense for training in asynchronous mode. Hence, when using Gluon with `async` mode we need to set `update_on_kvstore` to `True`. - `dist_sync_device`: Same as `dist_sync` except that when there are multiple GPUs being used on each node, this mode aggregates gradients and updates weights on GPU while dist_sync does so on CPU memory. 
From 90e60b7809df29cc53f8a38dc6a91abc4dc40269 Mon Sep 17 00:00:00 2001 From: Rahul Date: Thu, 26 Jul 2018 15:55:45 -0700 Subject: [PATCH 2/7] add async check for gluon --- python/mxnet/gluon/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index b4263410a50b..1e74b1558e7c 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -187,6 +187,10 @@ def _init_kvstore(self): arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) + if 'async' in kvstore.type and not config['update_on_kvstore']: + raise ValueError("Please set update_on_kvstore to true " + "when training in async mode.") + if config['update_on_kvstore'] is not None: update_on_kvstore = config['update_on_kvstore'] if kvstore: @@ -195,7 +199,8 @@ def _init_kvstore(self): self._distributed = 'dist' in kvstore.type if self._distributed: # kv.pull(row_sparse_grad) is not supported for dist kvstore - update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad + update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \ + or 'async' in kvstore.type if update_on_kvstore: # optimizer preferably needs to be set before init for multiprecision kvstore.set_optimizer(self._optimizer) From 199b0fd269fdc159d6a543671f01938899fe2d69 Mon Sep 17 00:00:00 2001 From: Rahul Date: Thu, 26 Jul 2018 18:21:32 -0700 Subject: [PATCH 3/7] only raise error if user set update_on_kvstore --- python/mxnet/gluon/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 1e74b1558e7c..9ab5db258ef6 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -187,7 +187,8 @@ def _init_kvstore(self): arg_arrays = {param.name: param.data(self._contexts[0]) 
for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) - if 'async' in kvstore.type and not config['update_on_kvstore']: + if 'async' in kvstore.type and config['update_on_kvstore'] is not None\ + and config['update_on_kvstore']: raise ValueError("Please set update_on_kvstore to true " "when training in async mode.") From 2d89104041db48e4b0db4c371a8da6b0de065dd4 Mon Sep 17 00:00:00 2001 From: Rahul Date: Thu, 26 Jul 2018 19:23:12 -0700 Subject: [PATCH 4/7] fix condition --- python/mxnet/gluon/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 9ab5db258ef6..5f4a197b763c 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -188,7 +188,7 @@ def _init_kvstore(self): kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) if 'async' in kvstore.type and config['update_on_kvstore'] is not None\ - and config['update_on_kvstore']: + and not config['update_on_kvstore']: raise ValueError("Please set update_on_kvstore to true " "when training in async mode.") From e20c1436b502ce60f5bd93d344c067e495c7705d Mon Sep 17 00:00:00 2001 From: Rahul Date: Thu, 26 Jul 2018 20:01:08 -0700 Subject: [PATCH 5/7] add async nightly test --- tests/nightly/dist_async_kvstore.py | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/nightly/dist_async_kvstore.py diff --git a/tests/nightly/dist_async_kvstore.py b/tests/nightly/dist_async_kvstore.py new file mode 100644 index 000000000000..3e400eafa045 --- /dev/null +++ b/tests/nightly/dist_async_kvstore.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +import sys +sys.path.insert(0, "../../python/") +import mxnet as mx + +kv = mx.kv.create('dist_async') +my_rank = kv.rank +nworker = kv.num_workers + +def test_gluon_trainer_type(): + def check_trainer_kv_update(update_on_kv): + params = mx.gluon.ParameterDict() + x = params.get('x', shape=(10,1), lr_mult=1.0) + params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + try: + trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) + trainer._init_kvstore() + assert trainer._kv_initialized + assert trainer._update_on_kvstore is True + except ValueError: + assert update_on_kv is False + + check_trainer_kv_update(False) + check_trainer_kv_update(True) + check_trainer_kv_update(None) + print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type') + +if __name__ == "__main__": + test_gluon_trainer_type() \ No newline at end of file From 1ec36ba88d92e2912234e495b8ecc98adc06be7e Mon Sep 17 00:00:00 2001 From: Rahul Date: Thu, 26 Jul 2018 21:26:00 -0700 Subject: [PATCH 6/7] fix case when no kvstore --- python/mxnet/gluon/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 5f4a197b763c..98a6878b94ba 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -187,7 +187,7 @@ def _init_kvstore(self): 
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) - if 'async' in kvstore.type and config['update_on_kvstore'] is not None\ + if kvstore and 'async' in kvstore.type and config['update_on_kvstore'] is not None\ and not config['update_on_kvstore']: raise ValueError("Please set update_on_kvstore to true " "when training in async mode.") From ad57d1ed56e34dceda2b06ad3aa46e2e18816fac Mon Sep 17 00:00:00 2001 From: Rahul Date: Mon, 30 Jul 2018 11:24:47 -0700 Subject: [PATCH 7/7] add example for trainer creation in doc --- docs/faq/distributed_training.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/faq/distributed_training.md b/docs/faq/distributed_training.md index 1f89a8dff121..d4fa72db23a0 100644 --- a/docs/faq/distributed_training.md +++ b/docs/faq/distributed_training.md @@ -76,7 +76,17 @@ to see an example usage. ### Updating weights KVStore server supports two modes, one which aggregates the gradients and updates the weights using those gradients, and second where the server only aggregates gradients. In the latter case, when a worker process pulls from kvstore, it gets the aggregated gradients. The worker then uses these gradients and applies the weights locally. -When using Gluon there is an option to choose between these modes by passing `update_on_kvstore` variable when you create the [Trainer](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/gluon.html#mxnet.gluon.Trainer) object. 
+When using Gluon there is an option to choose between these modes by passing the `update_on_kvstore` variable when you create the [Trainer](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/gluon.html#mxnet.gluon.Trainer) object like this:
+
+```
+trainer = gluon.Trainer(net.collect_params(), optimizer='sgd',
+                        optimizer_params={'learning_rate': opt.lr,
+                                          'wd': opt.wd,
+                                          'momentum': opt.momentum,
+                                          'multi_precision': True},
+                        kvstore=kv,
+                        update_on_kvstore=True)
+```
 
 When using the symbolic interface, it performs the weight updates on the server without the user having to do anything special.