From 7e244e4d34038c4a2e1965d64ac50f56303a1b8c Mon Sep 17 00:00:00 2001 From: Shaochen Shi Date: Fri, 24 Sep 2021 13:39:10 +0800 Subject: [PATCH] Allow to scale LR in different ways. --- deepmd/train/trainer.py | 13 ++++++++++++- deepmd/utils/argcheck.py | 4 +++- doc/train/parallel-training.md | 14 ++++++++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 7556cc1ebb..a8d85d86e6 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -157,6 +157,13 @@ def _init_param(self, jdata): # learning rate lr_param = j_must_have(jdata, 'learning_rate') + scale_by_worker = lr_param.get('scale_by_worker', 'linear') + if scale_by_worker == 'linear': + self.scale_lr_coef = float(self.run_opt.world_size) + elif scale_by_worker == 'sqrt': + self.scale_lr_coef = np.sqrt(self.run_opt.world_size).real + else: + self.scale_lr_coef = 1. lr_type = lr_param.get('type', 'exp') if lr_type == 'exp': self.lr = LearningRateExp(lr_param['start_lr'], @@ -330,7 +337,11 @@ def _build_network(self, data): def _build_training(self): trainable_variables = tf.trainable_variables() if self.run_opt.is_distrib: - optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate*self.run_opt.world_size) + if self.scale_lr_coef > 1.: + log.info('Scale learning rate by coef: %f', self.scale_lr_coef) + optimizer = tf.train.AdamOptimizer(self.learning_rate*self.scale_lr_coef) + else: + optimizer = tf.train.AdamOptimizer(self.learning_rate) optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) else: optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8e55187574..0ea84d8e73 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -452,8 +452,10 @@ def learning_rate_variant_type_args(): def learning_rate_args(): + doc_scale_by_worker = 'When parallel training or batch size scaled, how to alter learning rate. 
Valid values are `linear` (default), `sqrt`, or `none`.' doc_lr = "The definitio of learning rate" - return Argument("learning_rate", dict, [], + return Argument("learning_rate", dict, + [Argument("scale_by_worker", str, optional=True, default='linear', doc=doc_scale_by_worker)], [learning_rate_variant_type_args()], doc = doc_lr) diff --git a/doc/train/parallel-training.md b/doc/train/parallel-training.md index bd1bdb889a..b252446971 100644 --- a/doc/train/parallel-training.md +++ b/doc/train/parallel-training.md @@ -3,9 +3,19 @@ Currently, parallel training is enabled in a sychoronized way with help of [Horovod](https://github.com/horovod/horovod). Depend on the number of training processes (according to MPI context) and number of GPU cards avaliable, DeePMD-kit will decide whether to launch the training in parallel (distributed) mode or in serial mode. Therefore, no additional options is specified in your JSON/YAML input file. -Horovod works in the data-parallel mode, resulting in a larger global batch size. For example, the real batch size is 8 when `batch_size` is set to 2 in the input file and you launch 4 workers. Thus, `learning_rate` is automatically scaled by the number of workers for better convergence. The number of decay steps required to achieve same accuracy will also reduce based on the number of cards (e.g., 1/4 of steps in the above case), but needs to be scaled manually in the input file. +## Tuning learning rate -Technical details of such heuristic rule are discussed at [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). +Horovod works in the data-parallel mode, resulting in a larger global batch size. For example, the real batch size is 8 when `batch_size` is set to 2 in the input file and you launch 4 workers. Thus, `learning_rate` is automatically scaled by the number of workers for better convergence. 
Technical details of such heuristic rule are discussed at [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). + +The number of decay steps required to achieve the same accuracy can decrease by the number of cards (e.g., 1/2 of steps in the above case), but needs to be scaled manually in the input file. + +In some cases, it won't work well when scaling the learning rate by worker count in a `linear` way. Then you can try `sqrt` or `none` by setting the argument `scale_by_worker` as shown below. +```json + "learning_rate" :{ + "scale_by_worker": "none", + "type": "exp" + } +``` ## Scaling test